{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.optimizer import *\n",
    "from fastai.callback.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner, Metrics, and Basic Callbacks\n",
    "\n",
    "> Basic class for handling the training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You probably want to jump directly to the definition of `Learner`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#For tests\n",
    "from torch.utils.data import TensorDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):\n",
    "    \"A simple dataset where `x` is random and `y = a*x + b` plus some noise.\"\n",
    "    def get_data(n):\n",
    "        x = torch.randn(int(bs*n))\n",
    "        return TensorDataset(x, a*x + b + 0.1*torch.randn(int(bs*n)))\n",
    "    train_ds = get_data(n_train)\n",
    "    valid_ds = get_data(n_valid)\n",
    "    device = default_device() if cuda else None\n",
    "    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)\n",
    "    valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)\n",
    "    return DataLoaders(train_dl, valid_dl, device=device)\n",
    "\n",
    "class RegModel(Module):\n",
    "    \"A r\"\n",
    "    def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n",
    "    def forward(self, x): return x*self.a + self.b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
     "# Library-wide default learning rate; used as the default `lr` of `Learner` and `synth_learner` (bound when they are defined)\n",
     "defaults.lr = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def replacing_yield(o, attr, val):\n",
    "    \"Context manager to temporarily replace an attribute\"\n",
    "    old = getattr(o,attr)\n",
    "    try:     yield setattr(o,attr,val)\n",
    "    finally: setattr(o,attr,old)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _A:\n",
    "    def __init__(self, a): self.a = a\n",
    "    @contextmanager\n",
    "    def a_changed(self, v): return replacing_yield(self, 'a', v)\n",
    "\n",
    "a = _A(42)\n",
    "with a.a_changed(32):\n",
    "    test_eq(a.a, 32)\n",
    "test_eq(a.a, 42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def mk_metric(m):\n",
    "    \"Convert `m` to an `AvgMetric`, unless it's already a `Metric`\"\n",
    "    return m if isinstance(m, Metric) else AvgMetric(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the class `Metric` below for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def save_model(file, model, opt, with_opt=True, pickle_protocol=2):\n",
    "    \"Save `model` to `file` along with `opt` (if available, and if `with_opt`)\"\n",
    "    if rank_distrib(): return # don't save if child proc\n",
    "    if opt is None: with_opt=False\n",
    "    state = get_model(model).state_dict()\n",
    "    if with_opt: state = {'model': state, 'opt':opt.state_dict()}\n",
    "    torch.save(state, file, pickle_protocol=pickle_protocol)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "`file` can be a `Path` object, a string or an open file object. `pickle_protocol` is passed along to `torch.save`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def load_model(file, model, opt, with_opt=True, device=None, strict=True):\n",
    "    \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n",
    "    distrib_barrier()\n",
    "    if isinstance(device, int): device = torch.device('cuda', device)\n",
    "    elif device is None: device = 'cpu'\n",
    "    state = torch.load(file, map_location=device)\n",
    "    hasopt = set(state)=={'model', 'opt'}\n",
    "    model_state = state['model'] if hasopt else state\n",
    "    get_model(model).load_state_dict(model_state, strict=strict)\n",
    "    if hasopt and with_opt:\n",
    "        try: opt.load_state_dict(state['opt'])\n",
    "        except:\n",
    "            if with_opt: warn(\"Could not load the optimizer state.\")\n",
    "    elif with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "`file` can be a `Path` object, a string or an open file object. If a `device` is passed, the model is loaded on it (an integer is interpreted as a CUDA device index), otherwise it's loaded on the CPU.\n",
     "\n",
     "If `strict` is `True`, the file must contain exactly the weights for every parameter key in `model`; if `strict` is `False`, only the keys present in both the saved state and `model` are loaded, and the others are ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _try_concat(o):\n",
    "    try:    return torch.cat(o)\n",
    "    except: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
     "# Events fired by `Learner.__enter__` / `Learner.__exit__` when the learner is used as a context manager\n",
     "_before_epoch = [event.before_fit, event.before_epoch]\n",
     "_after_epoch  = [event.after_epoch, event.after_fit]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class _ConstantFunc():\n",
    "    \"Returns a function that returns `o`\"\n",
    "    def __init__(self, o): self.o = o\n",
    "    def __call__(self, *args, **kwargs): return self.o"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learner -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
     "# Labels printed by `Learner.show_training_loop`: 'Start ...' / 'End ...' entries open / close an indentation\n",
     "# level, every other entry is an event name whose callbacks get listed\n",
     "_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',\n",
     "         'Start Batch Loop', 'before_batch', 'after_pred', 'after_loss', 'before_backward', 'after_backward',\n",
     "         'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',\n",
     "         'after_cancel_train', 'after_train', 'Start Valid', 'before_validate','Start Batch Loop',\n",
     "         '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',\n",
     "         'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',\n",
     "         'after_cancel_fit', 'after_fit']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
     "# Groups `dls`, `model`, `loss_func` and an optimizer, and runs the training loop by firing events on its callbacks\n",
     "class Learner():\n",
     "    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,\n",
     "                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,\n",
     "                 moms=(0.95,0.85,0.95)):\n",
     "        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))\n",
     "        # fall back on a `loss_func` attached to the training dataset\n",
     "        if loss_func is None:\n",
     "            loss_func = getattr(dls.train_ds, 'loss_func', None)\n",
     "            assert loss_func is not None, \"Could not infer loss function from the data, please pass a loss function.\"\n",
     "        self.dls,self.model = dls,model\n",
     "        # store the other init args as attributes (`self.lr`, `self.wd`, ...; `metrics` goes through its setter)\n",
     "        store_attr(but='dls,model,cbs')\n",
     "        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()\n",
     "        # `defaults.callbacks` come before the user `cbs`; callback classes are instantiated here\n",
     "        self.add_cbs([(cb() if isinstance(cb, type) else cb) for cb in L(defaults.callbacks)+L(cbs)])\n",
     "        self(\"after_create\")\n",
     "\n",
     "    @property\n",
     "    def metrics(self): return self._metrics\n",
     "    # metrics are always stored as an `L`, each element normalized by `mk_metric`\n",
     "    @metrics.setter\n",
     "    def metrics(self,v): self._metrics = L(v).map(mk_metric)\n",
     "\n",
     "    def _grab_cbs(self, cb_cls): return L(cb for cb in self.cbs if isinstance(cb, cb_cls))\n",
     "    def add_cbs(self, cbs): L(cbs).map(self.add_cb)\n",
     "    def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)\n",
     "    def add_cb(self, cb):\n",
     "        # each callback is also exposed as `self.<cb.name>`; refuse to shadow an existing attribute of another type\n",
     "        old = getattr(self, cb.name, None)\n",
     "        assert not old or isinstance(old, type(cb)), f\"self.{cb.name} already registered\"\n",
     "        cb.learn = self\n",
     "        setattr(self, cb.name, cb)\n",
     "        self.cbs.append(cb)\n",
     "        return self\n",
     "\n",
     "    def remove_cb(self, cb):\n",
     "        # passing a class removes every registered callback of that type\n",
     "        if isinstance(cb, type): self.remove_cbs(self._grab_cbs(cb))\n",
     "        else:\n",
     "            cb.learn = None\n",
     "            if hasattr(self, cb.name): delattr(self, cb.name)\n",
     "            if cb in self.cbs: self.cbs.remove(cb)\n",
     "\n",
     "    @contextmanager\n",
     "    def added_cbs(self, cbs):\n",
     "        self.add_cbs(cbs)\n",
     "        try: yield\n",
     "        finally: self.remove_cbs(cbs)\n",
     "\n",
     "    @contextmanager\n",
     "    def removed_cbs(self, cbs):\n",
     "        self.remove_cbs(cbs)\n",
     "        try: yield self\n",
     "        finally: self.add_cbs(cbs)\n",
     "\n",
     "    def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]\n",
     "\n",
     "    # `event_name` may be a single event name or a list of them (e.g. `_before_epoch`)\n",
     "    def __call__(self, event_name): L(event_name).map(self._call_one)\n",
     "\n",
     "    def _call_one(self, event_name):\n",
     "        # reject unknown event names, then call every callback in `sort_by_run` order\n",
     "        assert hasattr(event, event_name), event_name\n",
     "        [cb(event_name) for cb in sort_by_run(self.cbs)]\n",
     "\n",
     "    def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)\n",
     "    def create_opt(self):\n",
     "        self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)\n",
     "        # flag the optimizer state of the params from `norm_bias_params` (flags presumably read by the optimizer):\n",
     "        # no weight decay unless `wd_bn_bias`, forced training when `train_bn`\n",
     "        if not self.wd_bn_bias:\n",
     "            for p in self._bn_bias_state(True ): p['do_wd'] = False\n",
     "        if self.train_bn:\n",
     "            for p in self._bn_bias_state(False): p['force_train'] = True\n",
     "\n",
     "    def _split(self, b):\n",
     "        # the first `n_inp` elements are inputs, the rest targets; without `n_inp`, only the last element\n",
     "        # is the target (a 1-element batch is all input)\n",
     "        i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)\n",
     "        self.xb,self.yb = b[:i],b[i:]\n",
     "\n",
     "    def _step(self): self.opt.step()\n",
     "    def _backward(self): self.loss.backward()\n",
     "\n",
     "    def _with_events(self, f, event_type, ex, final=noop):\n",
     "        # before_X then f(); an `ex` raised inside triggers after_cancel_X; after_X and `final` always run\n",
     "        try:       self(f'before_{event_type}')       ;f()\n",
     "        except ex: self(f'after_cancel_{event_type}')\n",
     "        finally:   self(f'after_{event_type}')        ;final()\n",
     "\n",
     "    def all_batches(self):\n",
     "        self.n_iter = len(self.dl)\n",
     "        for o in enumerate(self.dl): self.one_batch(*o)\n",
     "\n",
     "    def _do_one_batch(self):\n",
     "        self.pred = self.model(*self.xb)\n",
     "        self('after_pred')\n",
     "        # batches without targets get no loss\n",
     "        if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)\n",
     "        self('after_loss')\n",
     "        # validation (or target-less batches) stop here; training continues with backward/step/zero_grad\n",
     "        if not self.training or not len(self.yb): return\n",
     "        self('before_backward')\n",
     "        self._backward()\n",
     "        self('after_backward')\n",
     "        self._step()\n",
     "        self('after_step')\n",
     "        self.opt.zero_grad()\n",
     "\n",
     "    def one_batch(self, i, b):\n",
     "        self.iter = i\n",
     "        self._split(b)\n",
     "        self._with_events(self._do_one_batch, 'batch', CancelBatchException)\n",
     "\n",
     "    def _do_epoch_train(self):\n",
     "        self.dl = self.dls.train\n",
     "        self._with_events(self.all_batches, 'train', CancelTrainException)\n",
     "\n",
     "    def _do_epoch_validate(self, ds_idx=1, dl=None):\n",
     "        if dl is None: dl = self.dls[ds_idx]\n",
     "        self.dl = dl\n",
     "        # no gradient tracking during validation\n",
     "        with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)\n",
     "\n",
     "    def _do_epoch(self):\n",
     "        self._do_epoch_train()\n",
     "        self._do_epoch_validate()\n",
     "\n",
     "    def _do_fit(self):\n",
     "        for epoch in range(self.n_epoch):\n",
     "            self.epoch=epoch\n",
     "            self._with_events(self._do_epoch, 'epoch', CancelEpochException)\n",
     "\n",
     "    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):\n",
     "        with self.added_cbs(cbs):\n",
     "            # the optimizer is created lazily (or recreated with `reset_opt`)\n",
     "            if reset_opt or not self.opt: self.create_opt()\n",
     "            # `lr`/`wd` given here override `self.lr`/`self.wd` as optimizer hyper-parameters\n",
     "            if wd is None: wd = self.wd\n",
     "            if wd is not None: self.opt.set_hypers(wd=wd)\n",
     "            self.opt.set_hypers(lr=self.lr if lr is None else lr)\n",
     "            self.n_epoch = n_epoch\n",
     "            # `_end_cleanup` drops the references to the last batch/predictions once fitting ends\n",
     "            self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)\n",
     "\n",
     "    def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None\n",
     "    # used as a context manager, the learner fires `_before_epoch` on entry and `_after_epoch` on exit\n",
     "    def __enter__(self): self(_before_epoch); return self\n",
     "    def __exit__(self, exc_type, exc_value, tb): self(_after_epoch)\n",
     "\n",
     "    def validation_context(self, cbs=None, inner=False):\n",
     "        # silence logging and the master bar; `inner=True` does not enter `self` (no `_before_epoch`/`_after_epoch` events)\n",
     "        cms = [self.no_logging(),self.no_mbar()]\n",
     "        if cbs: cms.append(self.added_cbs(cbs))\n",
     "        if not inner: cms.append(self)\n",
     "        return ContextManagers(cms)\n",
     "\n",
     "    def validate(self, ds_idx=1, dl=None, cbs=None):\n",
     "        if dl is None: dl = self.dls[ds_idx]\n",
     "        with self.validation_context(cbs=cbs): self._do_epoch_validate(ds_idx, dl)\n",
     "        # `final_record` is not set by this class (presumably by a recording callback); None otherwise\n",
     "        return getattr(self, 'final_record', None)\n",
     "\n",
     "    @delegates(GatherPredsCallback.__init__)\n",
     "    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,\n",
     "                  inner=False, reorder=True, cbs=None, **kwargs):\n",
     "        if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)\n",
     "        else:\n",
     "            try: len(dl)\n",
     "            except TypeError as e:\n",
     "                raise TypeError(\"`dl` is something other than a single `DataLoader` object\")\n",
     "        # freeze the current index order so the results can be sorted back to dataset order below\n",
     "        if reorder and hasattr(dl, 'get_idxs'):\n",
     "            idxs = dl.get_idxs()\n",
     "            dl = dl.new(get_idxs = _ConstantFunc(idxs))\n",
     "        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)\n",
     "        ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)\n",
     "        if with_loss: ctx_mgrs.append(self.loss_not_reduced())\n",
     "        with ContextManagers(ctx_mgrs):\n",
     "            self._do_epoch_validate(dl=dl)\n",
     "            if act is None: act = getattr(self.loss_func, 'activation', noop)\n",
     "            res = cb.all_tensors()\n",
     "            # predictions sit after the inputs (if gathered); decoded predictions are inserted two slots later\n",
     "            pred_i = 1 if with_input else 0\n",
     "            if res[pred_i] is not None:\n",
     "                res[pred_i] = act(res[pred_i])\n",
     "                if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))\n",
     "            if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())\n",
     "            return tuple(res)\n",
     "        # NOTE(review): unreachable -- the `with` block above always returns (or raises)\n",
     "        self._end_cleanup()\n",
     "\n",
     "    def predict(self, item, rm_type_tfms=None, with_input=False):\n",
     "        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)\n",
     "        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n",
     "        # decode inputs and decoded predictions together, then split them back at `n_inp`\n",
     "        i = getattr(self.dls, 'n_inp', -1)\n",
     "        inp = (inp,) if i==1 else tuplify(inp)\n",
     "        dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]\n",
     "        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])\n",
     "        res = dec_targ,dec_preds[0],preds[0]\n",
     "        if with_input: res = (dec_inp,) + res\n",
     "        return res\n",
     "\n",
     "    def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):\n",
     "        if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)\n",
     "        b = dl.one_batch()\n",
     "        _,_,preds = self.get_preds(dl=[b], with_decoded=True)\n",
     "        self.dls.show_results(b, preds, max_n=max_n, **kwargs)\n",
     "\n",
     "    def show_training_loop(self):\n",
     "        indent = 0\n",
     "        for s in _loop:\n",
     "            if s.startswith('Start'): print(f'{\" \"*indent}{s}'); indent += 2\n",
     "            elif s.startswith('End'): indent -= 2; print(f'{\" \"*indent}{s}')\n",
     "            else: print(f'{\" \"*indent} - {s:15}:', self.ordered_cbs(s))\n",
     "\n",
     "    @contextmanager\n",
     "    def no_logging(self): return replacing_yield(self, 'logger', noop)\n",
     "    @contextmanager\n",
     "    def no_mbar(self):    return replacing_yield(self, 'create_mbar', False)\n",
     "\n",
     "    @contextmanager\n",
     "    def loss_not_reduced(self):\n",
     "        # set `reduction` on the loss object when it has one, otherwise wrap the function with `reduction='none'`\n",
     "        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')\n",
     "        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))\n",
     "\n",
     "    @delegates(save_model)\n",
     "    def save(self, file, **kwargs):\n",
     "        file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n",
     "        save_model(file, self.model, getattr(self,'opt',None), **kwargs)\n",
     "        return file\n",
     "\n",
     "    @delegates(load_model)\n",
     "    def load(self, file, with_opt=True, device=None, **kwargs):\n",
     "        if device is None and hasattr(self.dls, 'device'): device = self.dls.device\n",
     "        # create the optimizer first so there is something to restore a saved optimizer state into\n",
     "        if self.opt is None: self.create_opt()\n",
     "        file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n",
     "        load_model(file, self.model, self.opt, with_opt=with_opt, device=device, **kwargs)\n",
     "        return self\n",
     "    \n",
     "    def to_detach(self,b,cpu=True,gather=True):\n",
     "        return self.dl.to_detach(b,cpu,gather) if hasattr(getattr(self,'dl',None),'to_detach') else to_detach(b,cpu,gather)\n",
    "\n",
     "# `x`/`y` properties (built by `add_props`): the current batch inputs `xb` / targets `yb`, unwrapped by `detuplify`\n",
     "Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "add_docs(Learner, \"Group together a `model`, some `dls` and a `loss_func` to handle training\",\n",
    "    add_cbs=\"Add `cbs` to the list of `Callback` and register `self` as their learner\",\n",
    "    add_cb=\"Add `cb` to the list of `Callback` and register `self` as their learner\",\n",
    "    remove_cbs=\"Remove `cbs` from the list of `Callback` and deregister `self` as their learner\",\n",
    "    remove_cb=\"Add `cb` from the list of `Callback` and deregister `self` as their learner\",\n",
    "    added_cbs=\"Context manage that temporarily adds `cbs`\",\n",
    "    removed_cbs=\"Context manage that temporarily removes `cbs`\",\n",
    "    ordered_cbs=\"List of `Callback`s, in order, for an `event` in the training loop\",\n",
    "    create_opt=\"Create an optimizer with default hyper-parameters\",\n",
    "    one_batch=\"Train or evaluate `self.model` on batch `(xb,yb)`\",\n",
    "    all_batches=\"Train or evaluate `self.model` on all the batches of `self.dl`\",\n",
    "    fit=\"Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.\",\n",
    "    validate=\"Validate on `dl` with potential new `cbs`.\",\n",
    "    get_preds=\"Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`\",\n",
    "    predict=\"Prediction on `item`, fully decoded, loss function decoded and probabilities\",\n",
    "    validation_context=\"A `ContextManagers` suitable for validation, with optional `cbs`\",\n",
    "    show_results=\"Show some predictions on `ds_idx`-th dataset or `dl`\",\n",
    "    show_training_loop=\"Show each step in the training loop\",\n",
    "    no_logging=\"Context manager to temporarily remove `logger`\",\n",
    "    no_mbar=\"Context manager to temporarily prevent the master progress bar from being created\",\n",
    "    loss_not_reduced=\"A context manager to evaluate `loss_func` with reduction set to none.\",\n",
    "    save=\"Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`\",\n",
    "    load=\"Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`\",\n",
    "    to_detach=\"Calls `to_detach` if `self.dl` provides a `.to_detach` function otherwise calls global `to_detach`\",         \n",
    "    __call__=\"Call `event_name` for all `Callback`s in `self.cbs`\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h2 id=\"Learner\" class=\"doc_header\"><code>class</code> <code>Learner</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
       "\n",
       "> <code>Learner</code>(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*)\n",
       "\n",
       "Group together a `model`, some `dls` and a `loss_func` to handle training"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`opt_func` will be used to create an optimizer when `Learner.fit` is called, with `lr` as a default learning rate. `splitter` is a function that takes `self.model` and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is `trainable_params`, which returns all trainable parameters of the model.\n",
    "\n",
     "`cbs` is one or a list of `Callback`s to pass to the `Learner`. `Callback`s are used for every tweak of the training loop. Each `Callback` is registered as an attribute of `Learner`, under its snake-case `name` (for instance `TrainEvalCallback` becomes `learn.train_eval`). At creation, all the callbacks in `defaults.callbacks` (`TrainEvalCallback`, `Recorder` and `ProgressCallback`) are associated with the `Learner`.\n",
    "\n",
    "`metrics` is an optional list of metrics, that can be either functions or `Metric`s (see below). \n",
    "\n",
     "`path` and `model_dir` are used to save and/or load models. Often `path` will be inferred from `dls`, but you can override it or pass a `Path` object to `model_dir`. Make sure you can write in `path/model_dir`!\n",
    "\n",
     "`wd` is the default weight decay used when training the model; `moms`, the default momentum values used in `Learner.fit_one_cycle`. `wd_bn_bias` controls whether weight decay is applied to `BatchNorm` layers and bias parameters.\n",
    "\n",
     "Lastly, `train_bn` controls whether `BatchNorm` layers are trained even when they are supposed to be frozen according to the `splitter`. Our empirical experiments have shown that this is the best behavior for those layers in transfer learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch interop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use regular PyTorch functionality for most of the arguments of the `Learner`, although the experience will be smoother with pure fastai objects and you will be able to use the full functionality of the library. The expectation is that the training loop will work smoothly even if you did not use fastai end to end. What you might lose are interpretation objects or showing functionality. The list below explains how to use plain PyTorch objects for all the arguments and what you might lose.\n",
    "\n",
    "The most important is `opt_func`. If you are not using a fastai optimizer, you will need to write a function that wraps your PyTorch optimizer in an `OptimWrapper`. See the [optimizer module](http://docs.fast.ai/optimizer) for more details. This is to ensure the library's schedulers/freeze API work with your code.\n",
    "\n",
    "- `dls` is a `DataLoaders` object, that you can create from standard PyTorch dataloaders. By doing so, you will lose all showing functionality like `show_batch`/`show_results`. You can check the [data block API](http://docs.fast.ai/tutorial.datablock) or the [mid-level data API tutorial](http://docs.fast.ai/tutorial.pets) to learn how to use fastai to gather your data!\n",
     "- `model` is a standard PyTorch model. You can use any model you like, just make sure it accepts the number of inputs you have in your `DataLoaders` and returns as many outputs as you have targets.\n",
     "- `loss_func` can be any loss function you like. It needs to be one of fastai's if you want to use `Learner.predict` or `Learner.get_preds`, or you will have to implement special methods (see more details after the `BaseLoss` documentation)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at the main thing the `Learner` class implements: the training loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
     "# Callbacks every `Learner` gets at creation; only set here if `defaults.callbacks` was not already defined\n",
     "if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.fit\" class=\"doc_header\"><code>Learner.fit</code><a href=\"__main__.py#L119\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.fit</code>(**`n_epoch`**, **`lr`**=*`None`*, **`wd`**=*`None`*, **`cbs`**=*`None`*, **`reset_opt`**=*`False`*)\n",
       "\n",
       "Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.fit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Uses `lr` and `wd` if they are provided, otherwise uses the default values given by the `lr` and `wd` attributes of `Learner`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All the examples use `synth_learner` which is a simple `Learner` training a linear regression model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "def synth_learner(n_train=10, n_valid=2, cuda=False, lr=defaults.lr, **kwargs):\n",
    "    data = synth_dbunch(n_train=n_train,n_valid=n_valid, cuda=cuda)\n",
    "    return Learner(data, RegModel(), loss_func=MSELossFlat(), lr=lr, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "#Training a few epochs should make the model better\n",
     "learn = synth_learner(lr=0.1)\n",
     "# fire before_fit/before_epoch by hand (as `fit` would) before calling the model directly\n",
     "learn(_before_epoch)\n",
     "learn.model = learn.model.cpu()\n",
     "xb,yb = learn.dls.one_batch()\n",
     "init_loss = learn.loss_func(learn.model(xb), yb)\n",
     "learn.fit(10)\n",
     "xb,yb = learn.dls.one_batch()\n",
     "final_loss = learn.loss_func(learn.model(xb), yb)\n",
     "assert final_loss < init_loss, (final_loss,init_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class TestTrainEvalCallback(Callback):\n",
    "    run_after,run_valid = TrainEvalCallback,False\n",
    "    def before_fit(self): \n",
    "        test_eq([self.pct_train,self.train_iter], [0., 0])\n",
    "        self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter\n",
    "    \n",
    "    def before_batch(self): test_eq(next(self.model.parameters()).device, find_device(self.xb))\n",
    "    \n",
    "    def after_batch(self):\n",
    "        assert self.training\n",
    "        test_eq(self.pct_train , self.old_pct_train+1/(self.n_iter*self.n_epoch))\n",
    "        test_eq(self.train_iter, self.old_train_iter+1)\n",
    "        self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter\n",
    "    \n",
    "    def before_train(self):\n",
    "        assert self.training and self.model.training\n",
    "        test_eq(self.pct_train, self.epoch/self.n_epoch)\n",
    "        self.old_pct_train = self.pct_train\n",
    "    \n",
    "    def before_validate(self):\n",
    "        assert not self.training and not self.model.training\n",
    "        \n",
    "learn = synth_learner(cbs=TestTrainEvalCallback)\n",
    "learn.fit(1)\n",
    "#Check order is properly taken into account\n",
    "learn.cbs = L(reversed(learn.cbs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#cuda\n",
    "#Check model is put on the GPU if needed\n",
     "# `TestTrainEvalCallback.before_batch` asserts the model parameters and the batch share a device\n",
     "learn = synth_learner(cbs=TestTrainEvalCallback, cuda=True)\n",
     "learn.fit(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Check wd is not applied on bn/bias when option wd_bn_bias=False\n",
     "class _TstModel(nn.Module):\n",
     "    def __init__(self):\n",
     "        super().__init__()\n",
     "        self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n",
     "        # `tst` is never used in `forward`: it only provides weight, bias and BatchNorm params to the optimizer\n",
     "        self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))\n",
     "        # random biases so that weight decay, if wrongly applied to them, would visibly change them\n",
     "        self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3) \n",
     "    def forward(self, x): return x * self.a + self.b\n",
     "    \n",
     "class _PutGrad(Callback):\n",
     "    # overwrite the gradients of `tst` with ones so its expected update is known exactly\n",
     "    def after_backward(self):\n",
     "        for p in self.learn.model.tst.parameters():\n",
     "            p.grad = torch.ones_like(p.data)\n",
     "    \n",
     "learn = synth_learner(n_train=5, opt_func = partial(SGD, wd=1, decouple_wd=True), cbs=_PutGrad)\n",
     "learn.model = _TstModel()\n",
     "init = [p.clone() for p in learn.model.tst.parameters()]\n",
     "learn.fit(1, lr=1e-2)\n",
     "end = list(learn.model.tst.parameters())\n",
     "# 5 batches -> 5 SGD steps of lr=1e-2 with unit gradients: exactly -0.05 when no weight decay is applied.\n",
     "# The Linear weight (end[0]) gets weight decay; its bias and the BatchNorm weight/bias must not (wd_bn_bias=False)\n",
     "assert not torch.allclose(end[0]-init[0], -0.05 * torch.ones_like(end[0]))\n",
     "for i in [1,2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.one_batch\" class=\"doc_header\"><code>Learner.one_batch</code><a href=\"__main__.py#L96\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.one_batch</code>(**`i`**, **`b`**)\n",
       "\n",
       "Train or evaluate `self.model` on batch `(xb,yb)`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.one_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an internal method called by `Learner.fit`. `i` is the index of this iteration in the epoch and `b` is the batch to process. In training mode, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation. Training or validation is controlled internally by the `TrainEvalCallback` through the `training` attribute.\n",
    "\n",
    "Nothing is returned, but the attributes `x`, `y`, `pred`, `loss` of the `Learner` are set with the proper values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = learn.dls.one_batch()\n",
    "learn.one_batch(0, b)\n",
    "test_eq(learn.x, b[0])\n",
    "test_eq(learn.y, b[1])\n",
    "out = learn.model(learn.x)\n",
    "test_eq(learn.pred, out)\n",
    "test_eq(learn.loss, learn.loss_func(out, b[1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "More generally, the following attributes of `Learner` are available and updated during the training loop:\n",
    "- `model`: the model used for training/validation\n",
    "- `dls`: the underlying `DataLoaders`\n",
    "- `loss_func`: the loss function used\n",
    "- `opt`: the optimizer used to update the model parameters\n",
    "- `opt_func`: the function used to create the optimizer\n",
    "- `cbs`: the list containing all `Callback`s\n",
    "- `dl`: current `DataLoader` used for iteration\n",
    "- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.\n",
    "- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.\n",
    "- `pred`: last predictions from `self.model` (potentially modified by callbacks)\n",
    "- `loss`: last computed loss (potentially modified by callbacks)\n",
    "- `n_epoch`: the number of epochs in this training\n",
    "- `n_iter`: the number of iterations in the current `self.dl`\n",
    "- `epoch`: the current epoch index (from 0 to `n_epoch-1`)\n",
    "- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)\n",
    "\n",
    "The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:\n",
    "\n",
    "- `train_iter`: the number of training iterations done since the beginning of this training\n",
    "- `pct_train`: from 0. to 1., the percentage of training iterations completed\n",
    "- `training`:  flag to indicate if we're in training mode or not\n",
    "\n",
    "The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:\n",
    "\n",
    "- `smooth_loss`: an exponentially-averaged version of the training loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class VerboseCallback(Callback):\n",
    "    \"Callback that prints the name of each event called\"\n",
    "    def __call__(self, event_name):\n",
    "        print(event_name)\n",
    "        super().__call__(event_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class TestOneBatch(VerboseCallback):\n",
    "    \"Check the state `Learner.one_batch` leaves at each event when run on the fixed batch `(xb,yb)` at iteration `i`\"\n",
    "    def __init__(self, xb, yb, i):\n",
    "        self.save_xb,self.save_yb,self.i = xb,yb,i\n",
    "        self.old_pred,self.old_loss = None,tensor(0.)\n",
    "        \n",
    "    def before_batch(self):\n",
    "        self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()\n",
    "        test_eq(self.iter,    self.i)\n",
    "        test_eq(self.save_xb, *self.xb)\n",
    "        test_eq(self.save_yb, *self.yb)\n",
    "        if hasattr(self.learn, 'pred'): test_eq(self.pred, self.old_pred)\n",
    "    \n",
    "    def after_pred(self):\n",
    "        self.old_pred = self.pred\n",
    "        test_eq(self.pred, self.model.a.data * self.x + self.model.b.data)\n",
    "        test_eq(self.loss, self.old_loss)\n",
    "    \n",
    "    def after_loss(self):\n",
    "        self.old_loss = self.loss\n",
    "        test_eq(self.loss, self.loss_func(self.old_pred, self.save_yb))\n",
    "        # Gradients left by the previous batch must have been zeroed; skip params whose `grad` is still None\n",
    "        for p in self.model.parameters():\n",
    "            if getattr(p, 'grad', None) is not None: test_eq(p.grad, tensor([0.]))\n",
    "    \n",
    "    def after_backward(self):\n",
    "        # Analytic gradients of the MSE loss for the model `pred = a*x + b`\n",
    "        self.grad_a = (2 * self.x * (self.pred.data - self.y)).mean()\n",
    "        self.grad_b = 2 * (self.pred.data - self.y).mean()\n",
    "        test_close(self.model.a.grad.data, self.grad_a)\n",
    "        test_close(self.model.b.grad.data, self.grad_b)\n",
    "        # Parameters must not change before the optimizer step\n",
    "        test_eq(self.model.a.data, self.old_a)\n",
    "        test_eq(self.model.b.data, self.old_b)\n",
    "        \n",
    "    def after_step(self):\n",
    "        # Plain SGD update: new = old - lr * grad\n",
    "        test_close(self.model.a.data, self.old_a - self.lr * self.grad_a)\n",
    "        test_close(self.model.b.data, self.old_b - self.lr * self.grad_b)\n",
    "        self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()\n",
    "        test_close(self.model.a.grad.data, self.grad_a)\n",
    "        test_close(self.model.b.grad.data, self.grad_b)\n",
    "    \n",
    "    def after_batch(self):\n",
    "        # `one_batch` zeroes the gradients once the step is done\n",
    "        for p in self.model.parameters(): test_eq(p.grad, tensor([0.]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after_create\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "learn = synth_learner()\n",
    "b = learn.dls.one_batch()\n",
    "learn = synth_learner(cbs=TestOneBatch(*b, 42), lr=1e-2)\n",
    "#Remove train/eval\n",
    "learn.cbs = learn.cbs[1:]\n",
    "#Setup\n",
    "learn.loss,learn.training = tensor(0.),True\n",
    "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n",
    "learn.model.train()\n",
    "batch_events = ['before_batch', 'after_pred', 'after_loss', 'before_backward', 'after_backward', 'after_step', 'after_batch']\n",
    "test_stdout(lambda: learn.one_batch(42, b), '\\n'.join(batch_events))\n",
    "test_stdout(lambda: learn.one_batch(42, b), '\\n'.join(batch_events)) #Check it works for a second batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.all_batches\" class=\"doc_header\"><code>Learner.all_batches</code><a href=\"__main__.py#L79\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.all_batches</code>()\n",
       "\n",
       "Train or evaluate `self.model` on all the batches of `self.dl`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.all_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after_create\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "learn = synth_learner(n_train=5, cbs=VerboseCallback())\n",
    "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n",
    "with redirect_stdout(io.StringIO()): \n",
    "    learn(_before_epoch)\n",
    "    learn.epoch,learn.dl = 0,learn.dls.train\n",
    "    learn('before_train')\n",
    "test_stdout(learn.all_batches, '\\n'.join(batch_events * 5))\n",
    "test_eq(learn.train_iter, 5)\n",
    "\n",
    "valid_events = ['before_batch', 'after_pred', 'after_loss', 'after_batch']\n",
    "with redirect_stdout(io.StringIO()): \n",
    "    learn.dl = learn.dls.valid\n",
    "    learn('before_validate')\n",
    "test_stdout(learn.all_batches, '\\n'.join(valid_events * 2))\n",
    "test_eq(learn.train_iter, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after_create\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "learn = synth_learner(n_train=5, cbs=VerboseCallback())\n",
    "test_stdout(lambda: learn(_before_epoch), 'before_fit\\nbefore_epoch')\n",
    "test_eq(learn.loss, tensor(0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n",
    "learn.epoch = 0\n",
    "test_stdout(lambda: learn._do_epoch_train(), '\\n'.join(['before_train'] + batch_events * 5 + ['after_train']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "test_stdout(learn._do_epoch_validate, '\\n'.join(['before_validate'] + valid_events * 2+ ['after_validate']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.create_opt\" class=\"doc_header\"><code>Learner.create_opt</code><a href=\"__main__.py#L60\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.create_opt</code>()\n",
       "\n",
       "Create an optimizer with default hyper-parameters"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.create_opt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This method is called internally to create the optimizer. The hyper-parameters are then adjusted by what you pass to `Learner.fit` or by your particular schedulers (see `callback.schedule`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after_create\n"
     ]
    }
   ],
   "source": [
    "learn = synth_learner(n_train=5, cbs=VerboseCallback())\n",
    "assert learn.opt is None\n",
    "learn.create_opt()\n",
    "assert learn.opt is not None\n",
    "test_eq(learn.opt.hypers[0]['lr'], learn.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Serializing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.save\" class=\"doc_header\"><code>Learner.save</code><a href=\"__main__.py#L203\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.save</code>(**`file`**, **`with_opt`**=*`True`*, **`pickle_protocol`**=*`2`*)\n",
       "\n",
       "Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.save)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`file` can be a `Path`, a `string` or a buffer. `pickle_protocol` is passed along to `torch.save`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.load\" class=\"doc_header\"><code>Learner.load</code><a href=\"__main__.py#L209\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.load</code>(**`file`**, **`with_opt`**=*`True`*, **`device`**=*`None`*, **`strict`**=*`True`*)\n",
       "\n",
       "Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.load)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`file` can be a `Path`, a `string` or a buffer. Use `device` to load the model/optimizer state on a device different from the one it was saved from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as d:\n",
    "    learn = synth_learner(path=d)\n",
    "    learn.fit(1)\n",
    "    \n",
    "    #Test save created a file\n",
    "    learn.save('tmp')\n",
    "    assert (Path(d)/'models/tmp.pth').exists()\n",
    "    \n",
    "    #Test load did load the model\n",
    "    learn1 = synth_learner(path=d)\n",
    "    learn1 = learn1.load('tmp')\n",
    "    test_eq(learn.model.a, learn1.model.a)\n",
    "    test_eq(learn.model.b, learn1.model.b)\n",
    "    test_eq(learn.opt.state_dict(), learn1.opt.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:15: UserWarning: Saved filed doesn't contain an optimizer state.\n",
      "  from ipykernel import kernelapp as app\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "#Test load works when the model is saved without opt\n",
    "with tempfile.TemporaryDirectory() as d:\n",
    "    learn = synth_learner(path=d)\n",
    "    learn.fit(1)\n",
    "    learn.save('tmp', with_opt=False)\n",
    "    learn1 = synth_learner(path=d)\n",
    "    learn1 = learn1.load('tmp')\n",
    "    test_eq(learn.model.a, learn1.model.a)\n",
    "    test_eq(learn.model.b, learn1.model.b)\n",
    "    test_ne(learn.opt.state_dict(), learn1.opt.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Callback handling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We only describe the basic functionality linked to `Callback`s here. To learn more about `Callback`s and how to write them, check the [callback.core](http://docs.fast.ai/callback.core) module documentation.\n",
    "\n",
    "Let's first see how the `Callback`s become attributes of `Learner`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test init with callbacks\n",
    "class TstCallback(Callback):\n",
    "    def batch_begin(self): self.learn.a = self.a + 1\n",
    "\n",
    "tst_learn = synth_learner()\n",
    "test_eq(len(tst_learn.cbs), 1)\n",
    "assert isinstance(tst_learn.cbs[0], TrainEvalCallback)\n",
    "assert hasattr(tst_learn, ('train_eval'))\n",
    "\n",
    "tst_learn = synth_learner(cbs=TstCallback())\n",
    "test_eq(len(tst_learn.cbs), 2)\n",
    "assert isinstance(tst_learn.cbs[1], TstCallback)\n",
    "assert hasattr(tst_learn, ('tst'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A callback whose name clashes with an existing attribute of the `Learner` causes an exception (here `AddCbCallback` gets the name `add_cb`, which is already a method of `Learner`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AddCbCallback(Callback): pass\n",
    "test_fail(lambda: synth_learner(cbs=AddCbCallback()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.__call__\" class=\"doc_header\"><code>Learner.__call__</code><a href=\"__main__.py#L53\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.__call__</code>(**`event_name`**)\n",
       "\n",
       "Call `event_name` for all [`Callback`](/callback.core.html#Callback)s in `self.cbs`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.__call__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is how the `Callback`s are called internally. For instance, a `VerboseCallback` just prints the event names (which can be useful for debugging):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after_create\n",
      "after_fit\n"
     ]
    }
   ],
   "source": [
    "learn = synth_learner(cbs=VerboseCallback())\n",
    "learn('after_fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.add_cb\" class=\"doc_header\"><code>Learner.add_cb</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.add_cb</code>(**`cb`**)\n",
       "\n",
       "Add `cb` to the list of [`Callback`](/callback.core.html#Callback) and register `self` as their learner"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.add_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cb(TestTrainEvalCallback())\n",
    "test_eq(len(learn.cbs), 2)\n",
    "assert isinstance(learn.cbs[1], TestTrainEvalCallback)\n",
    "test_eq(learn.train_eval.learn, learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.add_cbs\" class=\"doc_header\"><code>Learner.add_cbs</code><a href=\"__main__.py#L22\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.add_cbs</code>(**`cbs`**)\n",
       "\n",
       "Add `cbs` to the list of [`Callback`](/callback.core.html#Callback) and register `self` as their learner"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.add_cbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])\n",
    "test_eq(len(learn.cbs), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.added_cbs\" class=\"doc_header\"><code>Learner.added_cbs</code><a href=\"__main__.py#L39\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.added_cbs</code>(**`cbs`**)\n",
       "\n",
       "Context manage that temporarily adds `cbs`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.added_cbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "test_eq(len(learn.cbs), 1)\n",
    "with learn.added_cbs(TestTrainEvalCallback()):\n",
    "    test_eq(len(learn.cbs), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.ordered_cbs\" class=\"doc_header\"><code>Learner.ordered_cbs</code><a href=\"__main__.py#L51\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.ordered_cbs</code>(**`event`**)\n",
       "\n",
       "List of [`Callback`](/callback.core.html#Callback)s, in order, for an `event` in the training loop"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.ordered_cbs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By order, we mean using the internal ordering of the `Callback`s (see `callback.core` for more information on how it works)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[TrainEvalCallback, TestTrainEvalCallback]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cb(TestTrainEvalCallback())\n",
    "learn.ordered_cbs('before_fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.remove_cb\" class=\"doc_header\"><code>Learner.remove_cb</code><a href=\"__main__.py#L32\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.remove_cb</code>(**`cb`**)\n",
       "\n",
       "Add `cb` from the list of [`Callback`](/callback.core.html#Callback) and deregister `self` as their learner"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.remove_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cb(TestTrainEvalCallback())\n",
    "cb = learn.cbs[1]\n",
    "learn.remove_cb(learn.cbs[1])\n",
    "test_eq(len(learn.cbs), 1)\n",
    "assert cb.learn is None\n",
    "assert not getattr(learn,'test_train_eval',None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cb` can simply be the class of the `Callback` we want to remove (in which case all instances of that callback are removed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])\n",
    "learn.remove_cb(TestTrainEvalCallback)\n",
    "test_eq(len(learn.cbs), 1)\n",
    "assert not getattr(learn,'test_train_eval',None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.remove_cbs\" class=\"doc_header\"><code>Learner.remove_cbs</code><a href=\"__main__.py#L23\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.remove_cbs</code>(**`cbs`**)\n",
       "\n",
       "Remove `cbs` from the list of [`Callback`](/callback.core.html#Callback) and deregister `self` as their learner"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.remove_cbs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Elements of `cbs` can either be types of callbacks or actual callbacks of the `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cbs([TestTrainEvalCallback() for _ in range(3)])\n",
    "cb = learn.cbs[1]\n",
    "learn.remove_cbs(learn.cbs[1:])\n",
    "test_eq(len(learn.cbs), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.removed_cbs\" class=\"doc_header\"><code>Learner.removed_cbs</code><a href=\"__main__.py#L45\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.removed_cbs</code>(**`cbs`**)\n",
       "\n",
       "Context manage that temporarily removes `cbs`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.removed_cbs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Elements of `cbs` can either be types of callbacks or actual callbacks of the `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "learn.add_cb(TestTrainEvalCallback())\n",
    "with learn.removed_cbs(learn.cbs[1]):\n",
    "    test_eq(len(learn.cbs), 1)\n",
    "test_eq(len(learn.cbs), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.show_training_loop\" class=\"doc_header\"><code>Learner.show_training_loop</code><a href=\"__main__.py#L186\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.show_training_loop</code>()\n",
       "\n",
       "Show each step in the training loop"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.show_training_loop)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At each step, callbacks are shown in order, which can help with debugging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start Fit\n",
      "   - before_fit     : [TrainEvalCallback]\n",
      "  Start Epoch Loop\n",
      "     - before_epoch   : []\n",
      "    Start Train\n",
      "       - before_train   : [TrainEvalCallback]\n",
      "      Start Batch Loop\n",
      "         - before_batch   : []\n",
      "         - after_pred     : []\n",
      "         - after_loss     : []\n",
      "         - before_backward: []\n",
      "         - after_backward : []\n",
      "         - after_step     : []\n",
      "         - after_cancel_batch: []\n",
      "         - after_batch    : [TrainEvalCallback]\n",
      "      End Batch Loop\n",
      "    End Train\n",
      "     - after_cancel_train: []\n",
      "     - after_train    : []\n",
      "    Start Valid\n",
      "       - before_validate: [TrainEvalCallback]\n",
      "      Start Batch Loop\n",
      "         - **CBs same as train batch**: []\n",
      "      End Batch Loop\n",
      "    End Valid\n",
      "     - after_cancel_validate: []\n",
      "     - after_validate : []\n",
      "  End Epoch Loop\n",
      "   - after_cancel_epoch: []\n",
      "   - after_epoch    : []\n",
      "End Fit\n",
      " - after_cancel_fit: []\n",
      " - after_fit      : []\n"
     ]
    }
   ],
   "source": [
    "learn = synth_learner()\n",
    "learn.show_training_loop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _before_batch_cb(f, self):\n",
    "    \"Pass the callback's current `xb,yb` through `f` and store the returned pair on the `Learner`\"\n",
    "    self.learn.xb,self.learn.yb = f(self, self.xb, self.yb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def before_batch_cb(f):\n",
    "    \"Shortcut for creating a Callback on the `before_batch` event, which takes and returns `xb,yb`\"\n",
    "    return Callback(before_batch=partial(_before_batch_cb, f))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to change the data passed to your model, you will generally want to hook into the `before_batch` event, like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TstCallback(Callback):\n",
    "    def before_batch(self):\n",
    "        self.learn.xb = self.xb + 1000\n",
    "        self.learn.yb = self.yb - 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since that is so common, we provide the `before_batch_cb` decorator to make it easier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@before_batch_cb\n",
    "def cb(self, xb, yb): return xb+1000,yb-1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Control flow testing -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "batch_events  = ['before_batch', 'after_pred', 'after_loss', 'before_backward', 'after_backward', 'after_step', 'after_batch']\n",
    "batchv_events = ['before_batch', 'after_pred', 'after_loss', 'after_batch']\n",
    "train_events  = ['before_train']    + batch_events  + ['after_train']\n",
    "valid_events  = ['before_validate'] + batchv_events + ['after_validate']\n",
    "epoch_events  = ['before_epoch'] + train_events + valid_events + ['after_epoch']\n",
    "cycle_events  = ['before_fit'] + epoch_events + ['after_fit']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn = synth_learner(n_train=1, n_valid=1)\n",
    "test_stdout(lambda: learn.fit(1, cbs=VerboseCallback()), '\\n'.join(cycle_events))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class TestCancelCallback(VerboseCallback):\n",
    "    def __init__(self, cancel_at=event.before_batch, exception=CancelBatchException, train=None):\n",
    "        def _interrupt(): \n",
    "            if train is None or train == self.training: raise exception()\n",
    "        setattr(self, cancel_at, _interrupt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test cancel batch\n",
    "for i,e in enumerate(batch_events[:-1]):\n",
    "    be = batch_events[:i+1] + ['after_cancel_batch', 'after_batch']\n",
    "    bev = be if i <3 else batchv_events\n",
    "    cycle = cycle_events[:3] + be + ['after_train', 'before_validate'] + bev + cycle_events[-3:]\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(cancel_at=e)), '\\n'.join(cycle))\n",
    "\n",
    "#CancelBatchException not caught if thrown in any other event\n",
    "for e in cycle_events:\n",
    "    if e not in batch_events[:-1]:\n",
    "        with redirect_stdout(io.StringIO()):\n",
    "            cb = TestCancelCallback(cancel_at=e)\n",
    "            test_fail(lambda: learn.fit(1, cbs=cb))\n",
    "            learn.remove_cb(cb) #Have to remove it manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test cancel train\n",
    "for i,e in enumerate(['before_train'] + batch_events):\n",
    "    be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else []) \n",
    "    be += ['after_cancel_train', 'after_train']\n",
    "    cycle = cycle_events[:3] + be + ['before_validate'] + batchv_events + cycle_events[-3:]\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelTrainException, True)), '\\n'.join(cycle))\n",
    "\n",
    "#CancelTrainException not caught if thrown in any other event\n",
    "for e in cycle_events:\n",
    "    if e not in ['before_train'] + batch_events[:-1]:\n",
    "        with redirect_stdout(io.StringIO()):\n",
    "            cb = TestCancelCallback(e, CancelTrainException)\n",
    "            test_fail(lambda: learn.fit(1, cbs=cb))\n",
    "            learn.remove_cb(cb) #Have to remove it manually  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test cancel valid\n",
    "for i,e in enumerate(['before_validate'] + batchv_events):\n",
    "    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else []) + ['after_cancel_validate']\n",
    "    cycle = cycle_events[:3] + batch_events + ['after_train', 'before_validate'] + bev + cycle_events[-3:]\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelValidException, False)), '\\n'.join(cycle))\n",
    "    \n",
    "#CancelValidException not caught if thrown in any other event\n",
    "for e in cycle_events:\n",
    "    if e not in ['before_validate'] + batch_events[:3]:\n",
    "        with redirect_stdout(io.StringIO()):\n",
    "            cb = TestCancelCallback(e, CancelValidException)\n",
    "            test_fail(lambda: learn.fit(1, cbs=cb))\n",
    "            learn.remove_cb(cb) #Have to remove it manually  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test cancel epoch\n",
    "#In train\n",
    "for i,e in enumerate(['before_train'] + batch_events):\n",
    "    be = batch_events[:i] + (['after_batch'] if i >=1 and i<len(batch_events) else []) \n",
    "    cycle = cycle_events[:3] + be + ['after_train', 'after_cancel_epoch'] + cycle_events[-2:]\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, True)), '\\n'.join(cycle))\n",
    "\n",
    "#In valid\n",
    "for i,e in enumerate(['before_validate'] + batchv_events):\n",
    "    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i<len(batchv_events) else [])\n",
    "    cycle = cycle_events[:3] + batch_events + ['after_train', 'before_validate'] + bev \n",
    "    cycle += ['after_validate', 'after_cancel_epoch'] + cycle_events[-2:]\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, False)), '\\n'.join(cycle))\n",
    "\n",
    "#In begin epoch\n",
    "test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('before_epoch', CancelEpochException, False)), \n",
    "            '\\n'.join(cycle_events[:2] + ['after_cancel_epoch'] + cycle_events[-2:]))\n",
    "\n",
    "#CancelEpochException is not caught if thrown in before_fit, after_epoch or after_fit\n",
    "for e in ['before_fit', 'after_epoch', 'after_fit']:\n",
    "    with redirect_stdout(io.StringIO()):\n",
    "        cb = TestCancelCallback(e, CancelEpochException)\n",
    "        test_fail(lambda: learn.fit(1, cbs=cb))\n",
    "        learn.remove_cb(cb) #Have to remove it manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test cancel fit\n",
    "#In begin fit\n",
    "test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('before_fit', CancelFitException)), \n",
    "            '\\n'.join(['before_fit', 'after_cancel_fit', 'after_fit']))\n",
    "\n",
    "#In begin epoch\n",
    "test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback('before_epoch', CancelFitException, False)), \n",
    "            '\\n'.join(cycle_events[:2] + ['after_epoch', 'after_cancel_fit', 'after_fit']))\n",
    "#In train\n",
    "for i,e in enumerate(['before_train'] + batch_events):\n",
    "    be = batch_events[:i] + (['after_batch'] if i >=1 and i<len(batch_events) else []) \n",
    "    cycle = cycle_events[:3] + be + ['after_train', 'after_epoch', 'after_cancel_fit', 'after_fit']\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, True)), '\\n'.join(cycle))\n",
    "    \n",
    "#In valid\n",
    "for i,e in enumerate(['before_validate'] + batchv_events):\n",
    "    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i<len(batchv_events) else [])\n",
    "    cycle = cycle_events[:3] + batch_events + ['after_train', 'before_validate'] + bev \n",
    "    cycle += ['after_validate', 'after_epoch', 'after_cancel_fit', 'after_fit']\n",
    "    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, False)), '\\n'.join(cycle))\n",
    "    \n",
    "#CancelEpochException not caught if thrown in any other event\n",
    "with redirect_stdout(io.StringIO()):\n",
    "    cb = TestCancelCallback('after_fit', CancelEpochException)\n",
    "    test_fail(lambda: learn.fit(1, cbs=cb))\n",
    "    learn.remove_cb(cb) #Have to remove it manually  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataLoader aware `to_detach` -"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "fastai provides `to_detach` which by default detaches tensor gradients, and gathers (calling `maybe_gather`) tensors from all ranks if running in distributed data parallel (DDP) mode.\n",
    "\n",
    "When running in DDP mode all ranks need to have the same batch size, and `DistributedDL` takes care of padding batches as needed; however when gathering all tensors (e.g. for calculating metrics, inference, etc.) we need to discard the padded items. `DistributedDL` provides a method `to_detach` that removes padding appropriately.\n",
    "\n",
    "Calling `to_detach_from_dl` with `learn` as a learner will attempt to find a `to_detach` method in the learner's last used `DataLoader` `dl` and use that one if found, otherwise it will resort to the vanilla `to_detach`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def to_detach_from_dl(learn:(Learner,NoneType),b:object,cpu:bool=True,gather:bool=True):\n",
    "    return learn.dl.to_detach(b,cpu,gather) if hasattr(getattr(learn,'dl',None),'to_detach') else to_detach(b,cpu,gather)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn = synth_learner()\n",
    "test_eq(to_detach_from_dl(learn,Tensor([123])),Tensor([123]))\n",
    "learn.dl = learn.dls[0]\n",
    "test_eq(to_detach_from_dl(learn,Tensor([123])),Tensor([123]))\n",
    "learn.dl.to_detach = lambda b,cpu,gather: b-100\n",
    "test_eq(to_detach_from_dl(learn,Tensor([123.])),Tensor([23.]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@docs\n",
    "class Metric():\n",
    "    \"Blueprint for defining a metric\"\n",
    "    def reset(self): pass\n",
    "    def accumulate(self, learn): pass\n",
    "    @property\n",
    "    def value(self): raise NotImplementedError\n",
    "\n",
    "    @property\n",
    "    def name(self): return class2attr(self, 'Metric')\n",
    "\n",
    "    _docs = dict(\n",
    "        reset=\"Reset inner state to prepare for new computation\",\n",
    "        name=\"Name of the `Metric`, camel-cased and with Metric removed\",\n",
    "        accumulate=\"Use `learn` to update the state with new results\",\n",
    "        value=\"The value of the metric\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"Metric\" class=\"doc_header\"><code>class</code> <code>Metric</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>Metric</code>()\n",
       "\n",
       "Blueprint for defining a metric"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Metric, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class `AvgMetric`, otherwise you'll need to implement the following methods.\n",
    "\n",
    "> Note: If your <code>Metric</code> has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Metric.reset\" class=\"doc_header\"><code>Metric.reset</code><a href=\"__main__.py#L5\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Metric.reset</code>()\n",
       "\n",
       "Reset inner state to prepare for new computation"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Metric.reset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Metric.accumulate\" class=\"doc_header\"><code>Metric.accumulate</code><a href=\"__main__.py#L6\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Metric.accumulate</code>(**`learn`**)\n",
       "\n",
       "Use `learn` to update the state with new results"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Metric.accumulate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Metric.value\" class=\"doc_header\"><code>Metric.value</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "The value of the metric"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Metric.value, name='Metric.value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Metric.name\" class=\"doc_header\"><code>Metric.name</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "Name of the [`Metric`](/learner.html#Metric), camel-cased and with Metric removed"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Metric.name, name='Metric.name')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _maybe_reduce(val):\n",
    "    if num_distrib()>1:\n",
    "        val = val.clone()\n",
    "        torch.distributed.all_reduce(val, op=torch.distributed.ReduceOp.SUM)\n",
    "        val /= num_distrib()\n",
    "    return val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AvgMetric(Metric):\n",
    "    \"Average the values of `func` taking into account potential different batch sizes\"\n",
    "    def __init__(self, func):  self.func = func\n",
    "    def reset(self):           self.total,self.count = 0.,0\n",
    "    def accumulate(self, learn):\n",
    "        bs = find_bs(learn.yb)\n",
    "        self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs\n",
    "        self.count += bs\n",
    "    @property\n",
    "    def value(self): return self.total/self.count if self.count != 0 else None\n",
    "    @property\n",
    "    def name(self):  return self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"AvgMetric\" class=\"doc_header\"><code>class</code> <code>AvgMetric</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>AvgMetric</code>(**`func`**) :: [`Metric`](/learner.html#Metric)\n",
       "\n",
       "Average the values of `func` taking into account potential different batch sizes"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(AvgMetric, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "tst = AvgMetric(lambda x,y: (x-y).abs().mean())\n",
    "t,u = torch.randn(100),torch.randn(100)\n",
    "tst.reset()\n",
    "for i in range(0,100,25): \n",
    "    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)\n",
    "    tst.accumulate(learn)\n",
    "test_close(tst.value, (t-u).abs().mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#With varying batch size\n",
    "tst.reset()\n",
    "splits = [0, 30, 50, 60, 100]\n",
    "for i in range(len(splits )-1): \n",
    "    learn.pred,learn.yb = t[splits[i]:splits[i+1]],(u[splits[i]:splits[i+1]],)\n",
    "    tst.accumulate(learn)\n",
    "test_close(tst.value, (t-u).abs().mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AvgLoss(Metric):\n",
    "    \"Average the losses taking into account potential different batch sizes\"\n",
    "    def reset(self):           self.total,self.count = 0.,0\n",
    "    def accumulate(self, learn):\n",
    "        bs = find_bs(learn.yb)\n",
    "        self.total += learn.to_detach(learn.loss.mean())*bs\n",
    "        self.count += bs\n",
    "    @property\n",
    "    def value(self): return self.total/self.count if self.count != 0 else None\n",
    "    @property\n",
    "    def name(self):  return \"loss\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"AvgLoss\" class=\"doc_header\"><code>class</code> <code>AvgLoss</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>AvgLoss</code>() :: [`Metric`](/learner.html#Metric)\n",
       "\n",
       "Average the losses taking into account potential different batch sizes"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(AvgLoss, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = AvgLoss()\n",
    "t = torch.randn(100)\n",
    "tst.reset()\n",
    "for i in range(0,100,25): \n",
    "    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()\n",
    "    tst.accumulate(learn)\n",
    "test_close(tst.value, t.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#With varying batch size\n",
    "tst.reset()\n",
    "splits = [0, 30, 50, 60, 100]\n",
    "for i in range(len(splits )-1): \n",
    "    learn.yb,learn.loss = t[splits[i]:splits[i+1]],t[splits[i]:splits[i+1]].mean()\n",
    "    tst.accumulate(learn)\n",
    "test_close(tst.value, t.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AvgSmoothLoss(Metric):\n",
    "    \"Smooth average of the losses (exponentially weighted with `beta`)\"\n",
    "    def __init__(self, beta=0.98): self.beta = beta\n",
    "    def reset(self):               self.count,self.val = 0,tensor(0.)\n",
    "    def accumulate(self, learn):\n",
    "        self.count += 1\n",
    "        self.val = torch.lerp(to_detach(learn.loss.mean(), gather=False), self.val, self.beta)\n",
    "    @property\n",
    "    def value(self): return self.val/(1-self.beta**self.count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"AvgSmoothLoss\" class=\"doc_header\"><code>class</code> <code>AvgSmoothLoss</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>AvgSmoothLoss</code>(**`beta`**=*`0.98`*) :: [`Metric`](/learner.html#Metric)\n",
       "\n",
       "Smooth average of the losses (exponentially weighted with `beta`)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(AvgSmoothLoss, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = AvgSmoothLoss()\n",
    "t = torch.randn(100)\n",
    "tst.reset()\n",
    "val = tensor(0.)\n",
    "for i in range(4): \n",
    "    learn.loss = t[i*25:(i+1)*25].mean()\n",
    "    tst.accumulate(learn)\n",
    "    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)\n",
    "    test_close(val/(1-0.98**(i+1)), tst.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ValueMetric(Metric):\n",
    "    \"Use to include a pre-calculated metric value (for instance calculated in a `Callback`) and returned by `func`\"\n",
    "    def __init__(self, func, metric_name=None): store_attr('func, metric_name')\n",
    "\n",
    "    @property\n",
    "    def value(self): return self.func()\n",
    "\n",
    "    @property\n",
    "    def name(self): return self.metric_name if self.metric_name else self.func.__name__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"ValueMetric\" class=\"doc_header\"><code>class</code> <code>ValueMetric</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>ValueMetric</code>(**`func`**, **`metric_name`**=*`None`*) :: [`Metric`](/learner.html#Metric)\n",
       "\n",
       "Use to include a pre-calculated metric value (for instance calculated in a [`Callback`](/callback.core.html#Callback)) and returned by `func`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ValueMetric, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def metric_value_fn(): return 5e-3\n",
    "\n",
    "vm = ValueMetric(metric_value_fn, 'custom_value_metric')\n",
    "test_eq(vm.value, 5e-3)\n",
    "test_eq(vm.name, 'custom_value_metric')\n",
    "\n",
    "vm = ValueMetric(metric_value_fn)\n",
    "test_eq(vm.name, 'metric_value_fn')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recorder --"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastprogress.fastprogress import format_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _maybe_item(t):\n",
    "    \"Return `t.value` as a Python scalar via `.item()` when possible, otherwise return it unchanged\"\n",
    "    t = t.value\n",
    "    # `value` may be None, a float, or a multi-element tensor/array; fall back to the raw value,\n",
    "    # but don't swallow KeyboardInterrupt/SystemExit like a bare `except:` would\n",
    "    try: return t.item()\n",
    "    except Exception: return t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Recorder(Callback):\n",
    "    \"Callback that registers statistics (lr, loss and metrics) during training\"\n",
    "    _stateattrs=('lrs','iters','losses','values')\n",
    "    remove_on_fetch,run_after = True,TrainEvalCallback\n",
    "\n",
    "    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):\n",
    "        store_attr('add_time,train_metrics,valid_metrics')\n",
    "        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Prepare state for training\"\n",
    "        self.lrs,self.iters,self.losses,self.values = [],[],[],[]\n",
    "        names = self.metrics.attrgot('name')\n",
    "        if self.train_metrics and self.valid_metrics:\n",
    "            names = L('loss') + names\n",
    "            names = names.map('train_{}') + names.map('valid_{}')\n",
    "        elif self.valid_metrics: names = L('train_loss', 'valid_loss') + names\n",
    "        else: names = L('train_loss') + names\n",
    "        if self.add_time: names.append('time')\n",
    "        self.metric_names = 'epoch'+names\n",
    "        self.smooth_loss.reset()\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Update all metrics and records lr and smooth loss in training\"\n",
    "        if len(self.yb) == 0: return\n",
    "        mets = self._train_mets if self.training else self._valid_mets\n",
    "        for met in mets: met.accumulate(self.learn)\n",
    "        if not self.training: return\n",
    "        self.lrs.append(self.opt.hypers[-1]['lr'])\n",
    "        self.losses.append(self.smooth_loss.value)\n",
    "        self.learn.smooth_loss = self.smooth_loss.value\n",
    "\n",
    "    def before_epoch(self):\n",
    "        \"Set timer if `self.add_time=True`\"\n",
    "        self.cancel_train,self.cancel_valid = False,False\n",
    "        if self.add_time: self.start_epoch = time.time()\n",
    "        self.log = L(getattr(self, 'epoch', 0))\n",
    "\n",
    "    def before_train   (self): self._train_mets[1:].map(Self.reset())\n",
    "    def before_validate(self): self._valid_mets.map(Self.reset())\n",
    "    def after_train   (self): self.log += self._train_mets.map(_maybe_item)\n",
    "    def after_validate(self): self.log += self._valid_mets.map(_maybe_item)\n",
    "    def after_cancel_train(self):    self.cancel_train = True\n",
    "    def after_cancel_validate(self): self.cancel_valid = True\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Store and log the loss/metric values\"\n",
    "        self.learn.final_record = self.log[1:].copy()\n",
    "        self.values.append(self.learn.final_record)\n",
    "        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))\n",
    "        self.logger(self.log)\n",
    "        self.iters.append(self.smooth_loss.count)\n",
    "\n",
    "    @property\n",
    "    def _train_mets(self):\n",
    "        if getattr(self, 'cancel_train', False): return L()\n",
    "        return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())\n",
    "\n",
    "    @property\n",
    "    def _valid_mets(self):\n",
    "        if getattr(self, 'cancel_valid', False): return L()\n",
    "        return (L(self.loss) + self.metrics if self.valid_metrics else L())\n",
    "\n",
    "    def plot_loss(self, skip_start=5, with_valid=True):\n",
    "        plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')\n",
    "        if with_valid:\n",
    "            idx = (np.array(self.iters)<skip_start).sum()\n",
    "            plt.plot(self.iters[idx:], L(self.values[idx:]).itemgot(1), label='valid')\n",
    "            plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "add_docs(Recorder,\n",
    "         before_train = \"Reset training metrics state (the smooth loss keeps accumulating across epochs)\",\n",
    "         after_train = \"Log the smooth loss, and metric values on the training set (if `self.train_metrics=True`)\",\n",
    "         before_validate = \"Reset validation loss and metrics state\",\n",
    "         after_validate = \"Log loss and metric values on the validation set\",\n",
    "         after_cancel_train = \"Ignore training metrics for this epoch\",\n",
    "         after_cancel_validate = \"Ignore validation metrics for this epoch\",\n",
    "         plot_loss = \"Plot the losses from `skip_start` and onward\")\n",
    "\n",
    "if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, metrics are computed on the validation set only, although that can be changed by adjusting `train_metrics` and `valid_metrics`. `beta` is the weight used to compute the exponentially weighted average of the losses (which gives the `smooth_loss` attribute to `Learner`).\n",
    "\n",
    "The `logger` attribute of a `Learner` determines what happens to those metrics. By default, it just prints them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test printed output\n",
    "def tst_metric(out, targ): return F.mse_loss(out, targ)\n",
    "learn = synth_learner(n_train=5, metrics=tst_metric)\n",
    "# pat = r\"[tensor\\(\\d.\\d*\\), tensor\\(\\d.\\d*\\), tensor\\(\\d.\\d*\\), 'dd:dd']\"\n",
    "pat = r\"\\[\\d, \\d+.\\d+, \\d+.\\d+, \\d+.\\d+, '\\d\\d:\\d\\d'\\]\"\n",
    "test_stdout(lambda: learn.fit(1), pat, regex=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "class TestRecorderCallback(Callback):\n",
    "    run_after=Recorder\n",
    "    \n",
    "    def before_fit(self): \n",
    "        self.train_metrics,self.add_time = self.recorder.train_metrics,self.recorder.add_time\n",
    "        self.beta = self.recorder.smooth_loss.beta\n",
    "        for m in self.metrics: assert isinstance(m, Metric)\n",
    "        test_eq(self.recorder.smooth_loss.val, 0.)\n",
    "        #To test what the recorder logs, we use a custom logger function.\n",
    "        self.learn.logger = self.test_log\n",
    "        self.old_smooth,self.count = tensor(0.),0\n",
    "    \n",
    "    def after_batch(self):\n",
    "        if self.training:\n",
    "            self.count += 1\n",
    "            test_eq(len(self.recorder.lrs), self.count)\n",
    "            test_eq(self.recorder.lrs[-1], self.opt.hypers[-1]['lr'])\n",
    "            test_eq(len(self.recorder.losses), self.count)\n",
    "            smooth = (1 - self.beta**(self.count-1)) * self.old_smooth * self.beta + self.loss * (1-self.beta)\n",
    "            smooth /= 1 - self.beta**self.count\n",
    "            test_close(self.recorder.losses[-1], smooth, eps=1e-4)\n",
    "            test_close(self.smooth_loss, smooth, eps=1e-4)\n",
    "            self.old_smooth = self.smooth_loss\n",
    "        self.bs += find_bs(self.yb)\n",
    "        if not self.training: test_eq(self.recorder.loss.count, self.bs)\n",
    "        if self.train_metrics or not self.training: \n",
    "            for m in self.metrics: test_eq(m.count, self.bs)\n",
    "        self.losses.append(self.loss.detach().cpu())\n",
    "    \n",
    "    def before_epoch(self): \n",
    "        if self.add_time: self.start_epoch = time.time()\n",
    "        self.log = [self.epoch]\n",
    "\n",
    "    def before_train(self):\n",
    "        self.bs = 0\n",
    "        self.losses = []\n",
    "        for m in self.recorder._train_mets: test_eq(m.count, self.bs)\n",
    "            \n",
    "    def after_train(self):\n",
    "        mean = tensor(self.losses).mean()\n",
    "        self.log += [self.smooth_loss, mean] if self.train_metrics else [self.smooth_loss]\n",
    "        test_close(self.log, self.recorder.log)\n",
    "        self.losses = []\n",
    "    \n",
    "    def before_validate(self):\n",
    "        self.bs = 0\n",
    "        self.losses = []\n",
    "        for m in [self.recorder.loss] + self.metrics: test_eq(m.count, self.bs)\n",
    "    \n",
    "    def test_log(self, log):\n",
    "        res = tensor(self.losses).mean()\n",
    "        self.log += [res, res]\n",
    "        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))\n",
    "        test_close(log[:-1], self.log[:-1])\n",
    "        test_eq(log[-1], self.log[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn = synth_learner(n_train=5, metrics = tst_metric, cbs = TestRecorderCallback)\n",
    "learn.fit(1)\n",
    "test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric', 'time'])\n",
    "\n",
    "learn = synth_learner(n_train=5, metrics = tst_metric, cbs = TestRecorderCallback)\n",
    "learn.recorder.train_metrics=True\n",
    "learn.fit(1)\n",
    "test_eq(learn.recorder.metric_names, \n",
    "        ['epoch', 'train_loss', 'train_tst_metric', 'valid_loss', 'valid_tst_metric', 'time'])\n",
    "\n",
    "learn = synth_learner(n_train=5, metrics = tst_metric, cbs = TestRecorderCallback)\n",
    "learn.recorder.add_time=False\n",
    "learn.fit(1)\n",
    "test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 7.783848762512207, 7.449564456939697, 7.449564456939697, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "#Test numpy metric\n",
    "def tst_metric_np(out, targ): return F.mse_loss(out, targ).numpy()\n",
    "learn = synth_learner(n_train=5, metrics=tst_metric_np)\n",
    "learn.fit(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Internals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.before_fit\" class=\"doc_header\"><code>Recorder.before_fit</code><a href=\"__main__.py#L11\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.before_fit</code>()\n",
       "\n",
       "Prepare state for training"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.before_fit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.before_epoch\" class=\"doc_header\"><code>Recorder.before_epoch</code><a href=\"__main__.py#L34\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.before_epoch</code>()\n",
       "\n",
       "Set timer if `self.add_time=True`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.before_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.before_validate\" class=\"doc_header\"><code>Recorder.before_validate</code><a href=\"__main__.py#L41\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.before_validate</code>()\n",
       "\n",
       "Reset loss and metrics state"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.before_validate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.after_batch\" class=\"doc_header\"><code>Recorder.after_batch</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.after_batch</code>()\n",
       "\n",
       "Update all metrics and records lr and smooth loss in training"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.after_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.after_epoch\" class=\"doc_header\"><code>Recorder.after_epoch</code><a href=\"__main__.py#L47\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.after_epoch</code>()\n",
       "\n",
       "Store and log the loss/metric values"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.after_epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Recorder.plot_loss\" class=\"doc_header\"><code>Recorder.plot_loss</code><a href=\"__main__.py#L65\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Recorder.plot_loss</code>(**`skip_start`**=*`5`*, **`with_valid`**=*`True`*)\n",
       "\n",
       "Plot the losses from `skip_start` and onward"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Recorder.plot_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAtjUlEQVR4nO3deXxU9b3/8deHrCSEAElYQxJEdggBQgjiglUQUMAFIUBoe9srxaVXvK3W1mur93a3i3sp9nrtD5FVEEVEWitFKIsBAwQQ2UmIkAASlgTI8vn9kcGmcQIzZJIzmfk8H488mJzzPTPvOc58PDnbR1QVY4wxgauZ0wGMMcY0LCv0xhgT4KzQG2NMgLNCb4wxAc4KvTHGBLhQpwO4Ex8frykpKU7HMMaYJmPz5s3HVTXB3Ty/LPQpKSnk5OQ4HcMYY5oMETlU1zzbdWOMMQHOCr0xxgQ4K/TGGBPg/HIfvTHGeKO8vJyCggLOnz/vdJQGFxkZSWJiImFhYR4vY4XeGNPkFRQUEBMTQ0pKCiLidJwGo6qcOHGCgoICunTp4vFytuvGGNPknT9/nri4uIAu8gAiQlxcnNd/uVihN8YEhEAv8pdczfsMmEJ/vrySV9bsZ/2+E05HMcYYvxIwhT6kmfDKR/uZvWaf01GMMUHm1KlTvPzyy14vN2bMGE6dOuX7QLUETKEPC2lGVkYSqz8rJv9kqdNxjDFBpK5CX1lZednlVqxYQatWrRoo1T8FTKEHmJzRmWYizN142Okoxpgg8vjjj7Nv3z7S0tIYPHgwN998M1OmTKFfv34A3HnnnQwaNIg+ffowe/bsL5dLSUnh+PHjHDx4kF69enHffffRp08fRo4cSVlZmc/yBdTplR1im3Nrr7YszMnnkRHdiAgNcTqSMaaRPf3ODnYWnvbpc/bu2JKfjO1T5/xf/vKX5OXlkZuby+rVq7n99tvJy8v78hTIV199lTZt2lBWVsbgwYO55557iIuL+5fn2LNnD/PmzeOVV15h4sSJvPnmm2RnZ/skf0Bt0QNkZyZz8txF3tt+1OkoxpgglZGR8S/nuT///PP079+fzMxM8vPz2bNnz1eW6dKlC2lpaQAMGjSIgwcP+ixPQG3RAwzrGk+X+Ghe33CIOwd0cjqOMaaRXW7Lu7FER0d/+Xj16tX89a9/Zf369URFRTF8+HC358FHRER8+TgkJMSnu26uuEUvIj1EJLfGz2kRmVlrTKyIvCMiW0Vkh4j8W415o0Rkt4jsFZHHfZa8Ds2aCVOHJJFz6At2fe7bP9+MMcadmJgYzpw543ZeSUkJrVu3Jioqik8//ZQNGzY0cjoPCr2q7lbVNFVNAwYBpcDSWsMeBHaqan9gOPBbEQkXkRDgJWA00BuYLCK9fZjfrQmDEokIbcbrG+q8PbMxxvhMXFwcw4YNo2/fvjz66KP/Mm/UqFFUVFSQmprKk08+SWZmZqPn83bXzS3APlWtXUEViJHqS7ZaACeBCmAIsFdV9wOIyHxgPLCzXqmvoFVUOGP7d2TpJ0d4fHRPYiI9v/mPMcZcjTfeeMPt9IiICN577z238y7th4+PjycvL+/L6d///vd9ms3bg7FZwDw3018EegGFwHbgYVWtAjoB+TXGFbimfYWITBeRHBHJKS4u9jLWV03LTKb0YiVvfXKk3s9ljDFNmceFXkTCgXHAIjezbwNygY5AGvCiiLQE3N2UQd09v6rOVtV0VU1PSHDb9tAr/Tu3ol+nWOZsOISq25c0xpig4M0W/Whgi6oeczPv34AlWm0vcADoSfUWfOca4xKp3upvFNMyk/ns2Fk+PvhFY72kMcb4HW8K/WTc77YBOEz1/ntEpB3QA9gPfAx0E5Eurr8IsoC3rz6ud8b270jLyFDm2EFZY0wQ86jQi0gUMAJYUmPaDBGZ4fr1f4DrRGQ78AHwA1U9rqoVwEPA+8AuYKGq7vDlG7ic5uEh3DMokZV5n1N8
5kJjvawxxvgVj866UdVSIK7WtFk1HhcCI+tYdgWwoh4Z6yU7M5n/W3eQhTn5PHjztU7FMMYYxwTcLRBq65rQguu6xvHGxsNUVtlBWWOMf2jRogUAhYWFTJgwwe2Y4cOHk5OTU+/XCvhCD9UHZY+cKuPDT4ucjmJM0KiorLIz3jzQsWNHFi9e3KCvERSF/tbe7WgbE8HrG+2grDGN5ecrPuU7czZTUVnldJRG8YMf/OBf7kn/1FNP8fTTT3PLLbcwcOBA+vXrx7Jly76y3MGDB+nbty8AZWVlZGVlkZqayqRJk3x2v5uAu6mZO2EhzZickcTzf9vD4ROlJMVFOR3JmID20Z5iXl13gG8MTSY0pJG3J997HI5u9+1ztu8Ho3952SFZWVnMnDmTBx54AICFCxeycuVKHnnkEVq2bMnx48fJzMxk3LhxdfZ9/cMf/kBUVBTbtm1j27ZtDBw40Cfxg2KLHmByRlJ1U5JNtlVvTEP64txFvr9oK9e2bcEPx/RyOk6jGTBgAEVFRRQWFrJ161Zat25Nhw4d+NGPfkRqaiq33norR44c4dgxd5ciVVuzZs2X96BPTU0lNTXVJ9mCYoseoH1sJCN6tWNRTgGP3NqdyDBrSmKMr6kqP1q6nZPnLvK/3xjszPfsClveDWnChAksXryYo0ePkpWVxdy5cykuLmbz5s2EhYWRkpLi9hbFNdW1tV8fQbNFDzWakuR97nQUYwLS4s0FvJd3lO+N7EHfTrFOx2l0WVlZzJ8/n8WLFzNhwgRKSkpo27YtYWFhfPjhhxw6dPk9CjfeeCNz584FIC8vj23btvkkV1AV+uu6xnFNfDSvb7Cessb42uETpTz19g6GdGnDfTdc43QcR/Tp04czZ87QqVMnOnTowNSpU8nJySE9PZ25c+fSs2fPyy5///33c/bsWVJTU/n1r39NRkaGT3IFza4bqG5KMmVIEj99dxc7C0/Tu2NLpyMZExAqKqt4ZGEuzZoJv5uURkgz3+9+aCq2b//ngeD4+HjWr1/vdtzZs2eB6gbhl25R3Lx5c+bPn+/zTEG1RQ9w76DORIY1s1MtjfGhl1fvY/OhL/jpnX3p1Kq503FMLUFX6GOjwhib2pG3PjnCmfPlTscxpsnLzT/Fcx/sYXxaR8anWZ9mfxR0hR5g2tDqpiRLrSmJMfVy7kIFM+d/QvuWkfz3+L6OZgmWq3Cv5n0GZaFPTWxFamIsc9ZbUxJj6uOn7+7k0MlSfjuxP7HNnWvZGRkZyYkTJwL++6yqnDhxgsjISK+WC6qDsTVlZybz2OJtbDpwkiHXxF15AWPMv1i14yjzNuUz46auZDr8HUpMTKSgoABftCH1d5GRkSQmJnq1TNAW+rGpHfnp8p3M2XDICr0xXio6c57Hl2ynT8eW/OeI7k7HISwsjC5dujgdw28F5a4bqG5Kcm96Z97fcZSiM5e/Us0Y80+qyqOLtnHuQgXPZaURHhq0ZaTJCOr/QlOHJFFeqSz8ON/pKMY0GXM2HOLvnxXzxO29uLZtjNNxjAeCutBfk9CC66+Nt6Ykxnhob9EZfvbuLm7qnsC0zGSn4xgPXbHQi0gPEcmt8XNaRGbWGvNojfl5IlIpIm1c8w6KyHbXvPq3SvGx7MwkCkvO8zdrSmLMZV2sqOLh+blER4TyzL2pDXLzLdMwrngwVlV3A2kAIhICHAGW1hrzDPCMa8xY4BFVPVljyM2qetxHmX3q1l7taNcygtc3HGJE73ZOxzHGb/3uL5+xo/A0s6cNom2Md6f3GWd5u+vmFmCfql7u/gGTgXlXH6lxhbqakvz9s2IOnTjndBxj/NKG/Sf445p9TM7ozMg+7Z2OY7zkbaHP4jJFXESigFHAmzUmK7BKRDaLyPTLLDtdRHJEJKexz4XNGpxESDPhjY12V0tjaispK+d7C7eS3CaK/7q9t9NxzFXwuNCLSDgwDlh0mWFjgXW1dtsMU9WBwGjgQRG50d2CqjpbVdNVNT0hIcHT
WD7RPjaSkb3bsTAnn/PllY362sb4ux8vy+Po6fM8mzWA6IigvfSmSfNmi340sEVV6+6D5WaLX1ULXf8WUb1v3zc3WPax7MxkvigtZ8V2a0pizCXLco+wLLeQh2/pRlrnVk7HMVfJm0J/2X3vIhIL3AQsqzEtWkRiLj0GRgJ5Vxe1YV3XNY5rEqKZs8FuX2wMwJFTZfzXW3kMSm7NA8O7Oh3H1INHhd61730EsKTGtBkiMqPGsLuAVapa84hmO2CtiGwFNgHvqurK+sf2PRFh6pBkPjl8ih2FJU7HMcZRlVXKfy7IpapK+f3ENEJDgvqSmybPo/96qlqqqnGqWlJj2ixVnVXj99dUNavWcvtVtb/rp4+q/sx30X1vwsDE6qYk1mrQBLlXPtrPxgMneWpcH5LiopyOY+rJ/jddQ2xUGOP6VzclOW1NSUyQyjtSwm9X7WZ03/ZMGOTdXRKNf7JCX8u0zBTKyitZusWakpjgU3axkpkLcmkTHc7P7+pnV78GCCv0tfRLjKV/YixzNlhTEhN8fvneLvYWneU39/andXS403GMj1ihdyM7M5m9RWfZeODklQcbEyA+3F3En9cf4lvDunBDt8a9lsU0LCv0bozt35HY5mF2qqUJGifOXuCxxdvo0S6Gx0b1cDqO8TEr9G5EhoVw76BE3s+zpiQm8Kkqjy/ZTklpOc9mpREZFuJ0JONjVujrMDUzmYoqZcEma0piAtuCj/P5y85jPDaqB706tHQ6jmkAVujr0CU+mhu6xTNv02EqKqucjmNMgzhw/BxPv7OT67rG8a1h1nM1UFmhv4ypQ5KtKYkJWOWVVcxckEt4aDN+O7E/zZrZqZSBygr9Zdzaqy3tW0byut2+2ASgF/62l635p/j5Xf3oENvc6TimAVmhv4xLTUnWfFbMwePWlMQEjs2HvuDFv+3h7oGduD21g9NxTAOzQn8FWRmdCW0mvLHJtupNYDh7oYJHFuTSsVVznh7Xx+k4phFYob+Cdi0jGdnHmpKYwPH02zso+KKU309KIyYyzOk4phFYofdA9pBkTpWW8+42a0pimrb3tn/Oos0FPDD8WgantHE6jmkkVug9MNSakpgAcLTkPD9cup3UxFgevrWb03FMI7JC7wERIXtIMrn5p8g7Yk1JTNNTVaU8ungrF8qreHZSGmHWSCSo2H9tD90z6FJTEtuqN03Pa/84yEd7jvNfd/TimoQWTscxjcwKvYdim4cxvn8nluUWUlJmTUlM07H76Bl+ufJTbunZlikZSU7HMQ64YqEXkR4iklvj57SIzKw15tEa8/NEpFJE2rjmjRKR3SKyV0Qeb6D30SimDU2mrLySJVsKnI5ijEcuVFTy8PxPaBkZyq8mpFojkSB1xUKvqrtVNU1V04BBQCmwtNaYZ2qM+SHwd1U9KSIhwEvAaKA3MFlEevv4PTSavp1i6d+5FXM3HramJKZJ+M37u/n06Bl+PSGV+BYRTscxDvF2180twD5VvdyO6snAPNfjDGCvq0n4RWA+MN77mP5jmqspyYb91pTE+Ld1e4/zykcHyM5M4ms92zkdxzjI20KfxT+L+FeISBQwCnjTNakTUPM+vwWuae6WnS4iOSKSU1xc7GWsxnNHagdim4fZQVnj106VXuR7C7dyTUI0T4xpsn9EGx/xuNCLSDgwDlh0mWFjgXWqemlz190OQbf7PFR1tqqmq2p6QoL/tjGLDAthYnoi7+84StFpa0pi/I+q8sTSPI6fvcBzkwbQPNwaiQQ7b7boRwNbVPXYZcbU3uIvADrX+D0RKPTiNf3SlCHVTUnmf2xNSYz/WfrJEd7d/jmPjOhOv8RYp+MYP+BNoa+57/0rRCQWuAlYVmPyx0A3Eeni+osgC3j7aoL6E2tKYvxV/slSfrxsBxkpbZhxU1en4xg/4VGhd+17HwEsqTFthojMqDHsLmCVqn55P19VrQAeAt4HdgELVXWHL4I7LTszmc9LzvOBNSUxfqKySnlkQS4C
/HZif0KskYhxCfVkkKqWAnG1ps2q9ftrwGtull0BrLjqhH7qlp5t6RAbyesbDnFbn/ZOxzGGWX/fR86hL/j9pP50bhPldBzjR+zK2Kt0qSnJR3uOc8CakhiHbSs4xe//8hl3pHbgzjS3J7aZIGaFvh6yBruakmy0Uy2Nc0ovVjBzfi4JMRH87M5+dvWr+Qor9PXQtmUkt/Vpz8KcAmtKYhzzs3d3ceDEOX47sT+xUdZIxHyVFfp6ys5MpqSsnOXWlMQ44INdx5i78TD33XAN13WNdzqO8VNW6Osp85o2XNu2hTUlMY2u+MwFHlu8jV4dWvK9kd2djmP8mBX6ehIRpg5JYmv+KbYXWFMS0zhUlR+8uY0zFyp4LiuNiFC7+tXUzQq9D9w9MJHmYSF2/xvTaOZuPMzfPi3ih6N70r1djNNxjJ+zQu8Dsc3DGJ/WkWVbj1hTEtPg9had5afv7uSGbvF8Y2iK03FME2CF3keyM5M5X17Fm5utKYlpOBcrqpi54BOah4Xwm3v708yufjUesELvI307xZLWuRWvbzxkTUlMg3nug8/IO3KaX9zdj3YtI52OY5oIK/Q+NC0zmf3F51i/74TTUUwA2nTgJC+v3sfE9ERG9e3gdBzThFih96HbUzvQKiqM1+1KWeNjp8+X88iCXJLaRPHjsX2cjmOaGCv0PlTdlKQz7+84xjFrSmJ86KllOzh6+jy/n5RGiwiP7kVozJes0PvYlIwkKquU+ZusKYnxjXe2FrLkkyM8dPO1DExq7XQc0wRZofexlPhobuyeYE1JjE8UnirjiaXbSevciu9+7Vqn45gmygp9A8geksTR0+f56y5rSmKuXlWV8r2FW6moUp6dlEZoiH1dzdWxT04D+FrPtnSMjWSuHZQ19fC/aw+wfv8JfjK2Nynx0U7HMU2YFfoGYE1JTH3tLDzNM+/vZmTvdkxM7+x0HNPEXbHQi0gPEcmt8XNaRGa6GTfcNX+HiPy9xvSDIrLdNS/Hx/n91qSM6qYkc+3+N8ZL58srmbngE2KjwvjlPanWSMTU2xXP01LV3UAagIiEAEeApTXHiEgr4GVglKoeFpG2tZ7mZlU97ovATUXbmEhu69ueRZsL+P5tPYgMs7sLGs/8auWnfHbsLH/+VgZtosOdjmMCgLe7bm4B9qlq7c3UKcASVT0MoKp2FJLqK2VLysp5Z2uh01FME7Hms2L+b91BvnldCjd1T3A6jgkQ3hb6LGCem+ndgdYislpENovI12vMU2CVa/r0up5YRKaLSI6I5BQXF3sZyz8N6dKGbm1b2O2LjUdOnrvI9xZtpVvbFjw+uqfTcUwA8bjQi0g4MA5Y5GZ2KDAIuB24DXhSRC61vBmmqgOB0cCDInKju+dX1dmqmq6q6QkJgbElIyJkZyaztaCEbQWnnI5j/Jiq8qMl2zlVepFns9JsV5/xKW+26EcDW1T1mJt5BcBKVT3n2he/BugPoKqFrn+LqN63n1G/yE3LXQM7WVMSc0WLcgpYueMo3x/Zgz4dY52OYwKMN4V+Mu532wAsA24QkVARiQKGALtEJFpEYgBEJBoYCeTVJ3BT0zIyjDsHdOLtrYWUlFpTEvNVh06c46l3dpB5TRv+/YZrnI5jApBHhd5VvEcAS2pMmyEiMwBUdRewEtgGbAL+pKp5QDtgrYhsdU1/V1VX+vYt+L/szCTOl1exeIs1JTH/qqKyipkLcglpJvxuYhoh1kjENACPboOnqqVAXK1ps2r9/gzwTK1p+3HtwglmfTrGMiCpFXM3HOJbw1LsvGjzpZc+3Mcnh0/x/OQBdGzV3Ok4JkDZlbGNZFpmMvuPn+Mf1pTEuGw5/AXP/20Pd6Z1ZFz/jk7HMQHMCn0jGdOvA62jwuygrAHg3IUKHlmQS/uWkTw9vq/TcUyAs0LfSC41JVm18xhHS6wpSbD7n+U7OXyylN9N7E9s8zCn45gAZ4W+EU0Z4mpK8vFhp6MYB63MO8r8j/OZcVNXhlwT
d+UFjKknK/SNKDkumptcTUnKrSlJUCo6fZ4fLtlG304teeTW7ldewBgfsELfyLIzkzl2+gIf7HJ33ZkJZKrK9xdvo6y8kmcnDSA81L5+pnHYJ62Rfa1nWzq1as7rG2z3TbD58z8OsuazYp4Y04tr27ZwOo4JIlboG1lIM2FyRmfW7j3O/uKzTscxjeSzY2f4xXufcnOPBLIzk52OY4KMFXoHTBzcmbAQYe5G26oPBhcqKpk5P5cWEaH8ekJ/u2DONDor9A5oGxPJbX3asygnn7KLlU7HMQ3sd6s+Y+fnp/nlPakkxEQ4HccEISv0DpmWmczp8xW8s82akgSyf+w7zuyP9jM5I4kRvds5HccEKSv0Dsno0obu7awpSSArKS3newu3khIXzZN39HI6jgliVugdcqkpybaCErbmn3I6jmkATy7Lo+jMBZ6dlEZUuEf3DzSmQVihd9BdAzoRFW5NSQLRW58c4e2thcy8pRv9O7dyOo4JclboHRRjTUkCUsEXpTz5Vh6Dkltz//CuTscxxgq907KHJHOhoopFm/OdjmJ8oLJK+c+FW1Hg2UlphIbYV8w4zz6FDuvdsSWDklszd+NhqqrU6Timnv64Zh+bDpzkqXF96Nwmyuk4xgBW6P1CdmYSB6wpSZOXd6SE3636jDH92nPPwE5OxzHmS1cs9CLSQ0Rya/ycFpGZbsYNd83fISJ/rzF9lIjsFpG9IvK4j/MHhNF9O9AmOtwOyjZhZRcreXj+J8S1COfnd/Wzq1+NX7niOV+quhtIAxCREOAIsLTmGBFpBbwMjFLVwyLStsb4l6huLF4AfCwib6vqTh++hyYvMiyEe9MT+dNHBzhacp72sZFORzJe+vmKXewrPsfr3x5Cq6hwp+MY8y+83XVzC7BPVWtvek4BlqjqYQBVLXJNzwD2qup+Vb0IzAfG1ydwoJqakUyVKvM22f1vmpoPPy1izoZDfPv6LlzfLd7pOMZ8hbeFPguY52Z6d6C1iKwWkc0i8nXX9E5AzdNJClzTvkJEpotIjojkFBcXexmr6UuKi7KmJE3Q8bMXeHTxVnq2j+HR23o4HccYtzwu9CISDowDFrmZHQoMAm4HbgOeFJHugLsdlW5PLVHV2aqarqrpCQkJnsYKKNlDkik6c4G/7rSmJE2BqvL4m9s5XVbBs1lpRIaFOB3JGLe82aIfDWxRVXdVqABYqarnVPU4sAbo75reuca4RMDu4lWHm11NSebYQdkmYd6mfP666xiPjepBz/YtnY5jTJ28KfSTcb/bBmAZcIOIhIpIFDAE2AV8DHQTkS6uvwiygLfrEziQhTQTpgxJ4h/7TrC3yJqS+LP9xWf5n+U7uf7aeL41rIvTcYy5LI8Kvat4jwCW1Jg2Q0RmAKjqLmAlsA3YBPxJVfNUtQJ4CHif6sK/UFV3+PYtBJaJ6ZeakthWvb8qr6zikQW5hIc24zf39qdZMzuV0vg3j26pp6qlQFytabNq/f4M8IybZVcAK+qRMagkxEQwqm8H3txcwGO39aR5uO339TfPf7CHrQUlvDx1oJ0Ka5oEuzLWD33ZlGSrHc7wNzkHT/LSh3u5Z2AiY/p1cDqOMR6xQu+HBqe0pnu7FnZQ1s+cOV/OIwtz6dS6OU+N6+10HGM8ZoXeD4kI0zKT2X7EmpL4i7KLlXx/0VaOfFHG7yemERMZ5nQkYzxmhd5P3elqSmJb9c47ePwcd728jlU7j/GjMb1IT2njdCRjvGKF3k/FRIZx14BOvLO1kFOlF52OE7RW7TjK2BfWcvT0ef7vm4P59xuucTqSMV6zQu/HsjOrm5Is3lzgdJSgU1FZxa9Wfsr0OZtJiY/mnYeuZ3iPtk7HMuaqWKH3Y706tCTdmpI0uuNnL/D1Vzfxh9X7mJyRxKIZQ62JiGnSrND7uezMZA4cP8e6fcedjhIUNh86ye3Pf8TmQ1/wzIRUfnF3P7uHjWnyrND7udH92ltTkkagqry27gCT/riBiNAQljxw
Hfemd77ygsY0AR5dGWucExEawsT0zsxes4/PS8roENvc6UgB59yFCh5fsp13thZya6+2/HZiGrHN7fRJEzhsi74JmDokCaX6bonGt/YWneXOl9bx7rZCHr2tB7OnpVuRNwHHCn0T0LlNFMO7JzDfmpL41IrtnzP+xbWcPHeROd8ewoM3X2s3KDMByQp9EzFtaHVTkr9YU5J6K6+s4qfLd/LA3C10bx/D8v+4nmHXWgtAE7is0DcRN3V3NSVZbwdl66Po9HmmvLKBP609wDevS2HB9KF23MMEPCv0TcSlpiTr959gb9EZp+M0SRv3n2DM82vJO3Ka57LSeGpcH8JD7StgAp99ypuQSYOrm5K8vuGw01GaFFVl9pp9TPnTRlpGhrLsoWGMT3Pbo96YgGSFvgmJbxHB6L4deHNLAaUXK5yO0yScOV/O/a9v4ecrPmVk73Yse2gY3dvFOB3LmEZlhb6JmTY0mTPnK3g715qSXMnuo2cY9+I6/rLrGE+M6cXLUwfa7YVNULpioReRHiKSW+PntIjMrDVmuIiU1Bjz4xrzDorIdtf0nAZ4D0ElPbk1PdrFMGfDIVTt/jd1WZZ7hDtfWsfZCxW88e9DuO/GaxCxUydNcLrilbGquhtIAxCREOAIsNTN0I9U9Y46nuZmVbWbtfiAiJA9NJkn38ojN/8UA5JaOx3Jr1ysqOJn7+7kz+sPkZHShhenDKBtS+vraoKbt7tubgH2qaqd4+eguwZ0Ijo8xA7K1lJ4qoyJf1zPn9cf4r4bujD3viFW5I3B+0KfBcyrY95QEdkqIu+JSJ8a0xVYJSKbRWR6XU8sItNFJEdEcoqLi72MFVxaRIRy18BOLN9WyBfnrCkJwLq9x7njhbXsOXaGl6cO5InbexMWYoegjAEvCr2IhAPjgEVuZm8BklW1P/AC8FaNecNUdSAwGnhQRG509/yqOltV01U1PSEhwdNYQcuaklSrqlJe+nAv0/53I3HR4bz93esZ06+D07GM8SvebPKMBrao6leuwVfV06p61vV4BRAmIvGu3wtd/xZRvW8/o96pDT3bt2RwSmvmbjwUtE1JSkrLmT4nh2fe380dqR1568FhdE1o4XQsY/yON4V+MnXsthGR9uI6pUFEMlzPe0JEokUkxjU9GhgJ5NUvsrkkOzOZgydKWbs3+I5z7ygsYeyLa1m9u5inx/Xhuaw0oiPsrtvGuOPRN0NEooARwHdqTJsBoKqzgAnA/SJSAZQBWaqqItIOWOr6f0Ao8IaqrvTtWwheo/q2J87VlOTG7sGzu2tRTj7/9VYeraPCWfCdoQxKtjOPjLkcjwq9qpYCcbWmzarx+EXgRTfL7Qf61zOjqUNEaAgTB3fmj3/fR+GpMjq2Cuybc50vr+Tpd3Ywb1M+13WN4/nJA4hvEeF0LGP8np2W0MRNyahuSjJ/U2Cfapl/spR7Z61n3qZ8HhjelTnfHmJF3hgPWaFv4jq3ieLmHm2Z93F+wDYl+XB3EXe8sJaDJ87xytfTeWxUT0KsQYgxHrNCHwCmZSZTfOYCq3YEVlOSyirl93/5jG+99jEdWzVn+XevZ0Tvdk7HMqbJsUIfAG7snkBi6+bM2XDQ6Sg+88W5i/zbax/z3Ad7uHtAIkvuv47kuGinYxnTJFmhDwAhzYSpQ5LZsP9kQDQl2Zp/ijteWMuGfSf4+V39+M29qTQPD3E6ljFNlhX6ADExPZHwkGZN+v43qsrcjYe4d9Z6ABbfP5QpQ5LsrpPG1JMV+gAR1yKCMf3a8+bmptmUpOxiJd9btJUnluaR2TWO5d+9ntTEVk7HMiYgWKEPINmZyZy5UMGyJtaU5ODxc9z18jqWfnKEh2/pxv99czCto8OdjmVMwLBCH0AGJbemZ/sY5qxvOk1JVu04ytgX13L09Hle/eZgHhnR3U6dNMbHrNAHEBEhOzOZnZ+f5pP8U07HuayKyip+tfJTps/ZTEpcNO88dD0392jrdCxjApIV+gBz55dN
Sfy3N8zxsxf4+qub+MPqfUzOSGLRjKF0bhPldCxjApYV+gDTIiKUuwcmsnzb537ZlGTzoZPc8fxaNh/6gmcmpPKLu/sRGWanThrTkKzQB6DszGQuVlSxaHO+01G+pKq8tu4Ak/64gfDQZix54DruTe/sdCxjgoIV+gDUo30MGSltmLvxsF80JTl3oYKH5+fy1Ds7Gd4jgXe+ez19OsY6HcuYoGGFPkBNzUzi0IlSPnK4KcneorPc+dI6lm8r5NHbejB7WjqxzcMczWRMsLFCH6BG9W1PfItwRw/Krtj+OeNfXMuJcxf5f98awoM3X0szO3XSmEZnhT5ARYSGMDG9Mx/sOsaRU2WN+trllVX8dPlOHpi7he7tY3j3P67n+m7xjZrBGPNPVugD2JQhjd+UpOj0eaa+spE/rT3AN4Yms2D6UDrEBnbnK2P83RULvYj0EJHcGj+nRWRmrTHDRaSkxpgf15g3SkR2i8heEXm8Ad6DqUNi6yi+1qMt8zblc7Gi4ZuSbNx/gjHPr2X7kRKey0rj6fF9CQ+1bQljnHbFb6Gq7lbVNFVNAwYBpcBSN0M/ujROVf8bQERCgJeA0UBvYLKI9PZZenNF2UOTOX72Aqt2Hm2w11BVZq/Zx5Q/baRlZChvPTiM8WmdGuz1jDHe8XZz6xZgn6p6eoQvA9irqvtV9SIwHxjv5WuaeripWwKd2zRnzvqGOSh75nw5D8zdws9XfMqIXu1Y9tAwerSPaZDXMsZcHW8LfRYwr455Q0Vkq4i8JyJ9XNM6ATWv2ilwTfsKEZkuIjkiklNcXOxlLFOXZq6mJBsPnGTPMd82Jdl99AzjX1zHqp3HeGJML/6QPZCYSDt10hh/43GhF5FwYBywyM3sLUCyqvYHXgDeurSYm7Fur+BR1dmqmq6q6QkJCZ7GMh64d9ClpiS+26pflnuEO19ax5kLFbzx70O478ZrrEGIMX7Kmy360cAWVf1KB2pVPa2qZ12PVwBhIhJP9RZ8zevcE4GmdbP0ABDXIoLbUzuwZMsRzl2oX1OSixVV/GRZHg/Pz6Vvp5a8+93rGXJNnI+SGmMagjeFfjJ17LYRkfbi2pwTkQzX854APga6iUgX118EWcDb9YtsrkZ2ZlK9m5J8XlLGpNnr+fP6Q9x3QxfeuC+Tti0jfZjSGNMQQj0ZJCJRwAjgOzWmzQBQ1VnABOB+EakAyoAsre58USEiDwHvAyHAq6q6w7dvwXhiYFJrenVoyesbDjE5o7PXu1nW7T3Od+d9woXySl6eOpAx/To0UFJjjK95VOhVtRSIqzVtVo3HLwIv1rHsCmBFPTIaH6huSpLEE0vz2HL4FIOSW3u0XFWV8oe/7+O3q3bTNaEFs6YNomtCiwZOa4zxJbuaJYjcmdaJFhGhzPXwoGxJWTnT5+TwzPu7uT21I289OMyKvDFNkBX6IBIdEcrdAzuxfNvnnLxCU5IdhSWMfWEtq3cX89TY3jyflUZ0hEd/ABpj/IwV+iCTnZnMxcoqFuXU3ZRkUU4+d7/8Dy5WVLHgO0P55rAuduqkMU2YFfog071dDBld3DclOV9eyQ+XbOPRxdsYmNSa5f9xvcf78o0x/ssKfRDKzkzm8MlS1uz55xXI+SdLuXfWeuZtyueB4V2Z8+0M4ltEOJjSGOMrttM1CI3q88+mJMN7tOXD3UXMnJ9LlSqzpw1iZJ/2Tkc0xviQFfogFB7ajEmDO/OH1ft46u0d/Hn9QXq0i2FW9iBS4qOdjmeM8THbdROkJmckAfDaPw5y14BOLH1gmBV5YwKUbdEHqcTWUfxkbB9auE65tLNqjAlcVuiD2DeuS3E6gjGmEdiuG2OMCXBW6I0xJsBZoTfGmABnhd4YYwKcFXpjjAlwVuiNMSbAWaE3xpgAZ4XeGGMC3BULvYj0EJHcGj+nRWRmHWMHi0iliEyoMe2giGx3LZvjw+zGGGM8cMUrY1V1N5AGICIhwBFgae1xrnm/oroReG03q+rxeiU1
xhhzVbzddXMLsE9V3TUd/S7wJlBU71TGGGN8xttCnwXMqz1RRDoBdwGz3CyjwCoR2Swi0+t6YhGZLiI5IpJTXFxc1zBjjDFe8rjQi0g4MA5Y5Gb2s8APVLXSzbxhqjoQGA08KCI3unt+VZ2tqumqmp6QkOBpLGOMMVfgzd0rRwNbVPWYm3npwHzXrW7jgTEiUqGqb6lqIYCqFonIUiADWFPP3MYYYzzkTaGfjJvdNgCq2uXSYxF5DViuqm+JSDTQTFXPuB6PBP67HnmNMcZ4yaNCLyJRwAjgOzWmzQBQVXf75S9pByx1bemHAm+o6sqrTmuMMcZrHhV6VS0F4mpNc1vgVfWbNR7vB/rXI58xxph6sitjjTEmwFmhN8aYAGeF3hhjApwVemOMCXCiqk5n+AoRKQbc3WbBE/GAP95Xx3J5x3J5x3J5JxBzJauq26tN/bLQ14eI5KhqutM5arNc3rFc3rFc3gm2XLbrxhhjApwVemOMCXCBWOhnOx2gDpbLO5bLO5bLO0GVK+D20RtjjPlXgbhFb4wxpgYr9MYYE+CaZKEXkVdFpEhE8uqYLyLyvIjsFZFtIjLQT3INF5GSGo3Wf9xIuTqLyIcisktEdojIw27GNPo68zBXo68zEYkUkU0istWV62k3Y5xYX57kcuQz5nrtEBH5RESWu5nnyHfSg1xOfScPish212vmuJnv2/Wlqk3uB7gRGAjk1TF/DPAeIEAmsNFPcg2n+l79jb2+OgADXY9jgM+A3k6vMw9zNfo6c62DFq7HYcBGINMP1pcnuRz5jLle+z+BN9y9vlPfSQ9yOfWdPAjEX2a+T9dXk9yiV9U1wMnLDBkP/D+ttgFoJSId/CCXI1T1c1Xd4np8BtgFdKo1rNHXmYe5Gp1rHZx1/Rrm+ql91oIT68uTXI4QkUTgduBPdQxx5DvpQS5/5dP11SQLvQc6Afk1fi/ADwqIy1DXn97viUifxn5xEUkBBlC9NViTo+vsMrnAgXXm+nM/FygC/qKqfrG+PMgFznzGngUeA6rqmO/U5+tZLp8LnFlfCqwSkc0iMt3NfJ+ur0At9OJmmj9s+Wyh+n4U/YEXgLca88VFpAXwJjBTVU/Xnu1mkUZZZ1fI5cg6U9VKVU0DEoEMEelba4gj68uDXI2+vkTkDqBIVTdfbpibaQ26vjzM5dR3cpiqDqS6F/eDInJjrfk+XV+BWugLgM41fk8ECh3K8iVVPX3pT29VXQGEiUh8Y7y2iIRRXUznquoSN0McWWdXyuXkOnO95ilgNTCq1ixHP2N15XJofQ0DxonIQWA+8DUReb3WGCfW1xVzOfX5UtVC179FwFIgo9YQn66vQC30bwNfdx25zgRKVPVzp0OJSHuR6ga6IpJB9fo/0QivK8D/ArtU9Xd1DGv0deZJLifWmYgkiEgr1+PmwK3Ap7WGObG+rpjLifWlqj9U1URVTQGygL+panatYY2+vjzJ5dDnK1pEYi49BkYCtc/U8+n68qhnrL8RkXlUHy2PF5EC4CdUH5hCq3vZrqD6qPVeoBT4Nz/JNQG4X0QqgDIgS12H2BvYMGAasN21fxfgR0BSjWxOrDNPcjmxzjoAfxaREKq/+AtVdbmIzKiRy4n15Ukupz5jX+EH68uTXE6sr3bAUtf/X0KBN1R1ZUOuL7sFgjHGBLhA3XVjjDHGxQq9McYEOCv0xhgT4KzQG2NMgLNCb4wxAc4KvTHGBDgr9MYYE+D+P0Oz6KL0PJAIAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "learn.recorder.plot_loss(skip_start=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.validate\" class=\"doc_header\"><code>Learner.validate</code><a href=\"__main__.py#L138\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.validate</code>(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`cbs`**=*`None`*)\n",
       "\n",
       "Validate on `dl` with potential new `cbs`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.validate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test result\n",
    "learn = synth_learner(n_train=5, metrics=tst_metric)\n",
    "res = learn.validate()\n",
    "test_eq(res[0], res[1])\n",
    "x,y = learn.dls.valid_ds.tensors\n",
    "test_close(res[0], F.mse_loss(learn.model(x), y), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test other dl\n",
    "res = learn.validate(dl=learn.dls.train)\n",
    "test_eq(res[0], res[1])\n",
    "x,y = learn.dls.train_ds.tensors\n",
    "test_close(res[0], F.mse_loss(learn.model(x), y), 1e-3)\n",
    "\n",
    "#Test additional callback is executed.\n",
    "cycle = cycle_events[:2] + ['before_validate'] + batchv_events * 2 + cycle_events[-3:]\n",
    "test_stdout(lambda: learn.validate(cbs=VerboseCallback()), '\\n'.join(cycle))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.get_preds\" class=\"doc_header\"><code>Learner.get_preds</code><a href=\"__main__.py#L143\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.get_preds</code>(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`with_input`**=*`False`*, **`with_decoded`**=*`False`*, **`with_loss`**=*`False`*, **`act`**=*`None`*, **`inner`**=*`False`*, **`reorder`**=*`True`*, **`cbs`**=*`None`*, **`save_preds`**=*`None`*, **`save_targs`**=*`None`*, **`concat_dim`**=*`0`*)\n",
       "\n",
       "Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.get_preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`with_decoded` will also return the decoded predictions using the <code>decodes</code> function of the loss function (if it exists). For instance, fastai's `CrossEntropyLossFlat` takes the argmax of the predictions in its decodes.\n",
    "\n",
    "Depending on the `loss_func` attribute of `Learner`, an activation function will be picked automatically so that the predictions make sense. For instance, if the loss is a case of cross-entropy, a softmax will be applied; if the loss is binary cross-entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with `act`.\n",
    "\n",
    "`save_preds` and `save_targs` should be used when your predictions are too big to fit in memory. Pass a `Path` object pointing to the folder where the predictions and targets should be saved.\n",
    "\n",
    "`concat_dim` is the batch dimension, along which all the tensors will be concatenated.\n",
    "\n",
    "`inner` is an internal argument that tells `get_preds` it is being called inside another training loop, to avoid recursion errors."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: If you want to use the option `with_loss=True` on a custom loss function, make sure it implements a `reduction` attribute that supports `'none'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test result\n",
    "learn = synth_learner(n_train=5, metrics=tst_metric)\n",
    "preds,targs = learn.get_preds()\n",
    "x,y = learn.dls.valid_ds.tensors\n",
    "test_eq(targs, y)\n",
    "test_close(preds, learn.model(x))\n",
    "\n",
    "preds,targs = learn.get_preds(act = torch.sigmoid)\n",
    "test_eq(targs, y)\n",
    "test_close(preds, torch.sigmoid(learn.model(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test get_preds work with ds not evenly divisible by bs\n",
    "learn = synth_learner(n_train=2.5, metrics=tst_metric)\n",
    "preds,targs = learn.get_preds(ds_idx=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test other dataset\n",
    "x = torch.randn(16*5)\n",
    "y = 2*x + 3 + 0.1*torch.randn(16*5)\n",
    "dl = TfmdDL(TensorDataset(x, y), bs=16)\n",
    "preds,targs = learn.get_preds(dl=dl)\n",
    "test_eq(targs, y)\n",
    "test_close(preds, learn.model(x))\n",
    "\n",
    "#Test with loss\n",
    "preds,targs,losses = learn.get_preds(dl=dl, with_loss=True)\n",
    "test_eq(targs, y)\n",
    "test_close(preds, learn.model(x))\n",
    "test_close(losses, F.mse_loss(preds, targs, reduction='none'))\n",
    "\n",
    "#Test with inputs\n",
    "inps,preds,targs = learn.get_preds(dl=dl, with_input=True)\n",
    "test_eq(inps,x)\n",
    "test_eq(targs, y)\n",
    "test_close(preds, learn.model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with no target\n",
    "learn = synth_learner(n_train=5)\n",
    "x = torch.randn(16*5)\n",
    "dl = TfmdDL(TensorDataset(x), bs=16)\n",
    "preds,targs = learn.get_preds(dl=dl)\n",
    "assert targs is None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with targets that are tuples\n",
    "def _fake_loss(x,y,z,reduction=None): return F.mse_loss(x,y)\n",
    "\n",
    "learn = synth_learner(n_train=5)\n",
    "x = torch.randn(16*5)\n",
    "y = 2*x + 3 + 0.1*torch.randn(16*5)\n",
    "learn.dls.n_inp=1\n",
    "learn.loss_func = _fake_loss\n",
    "dl = TfmdDL(TensorDataset(x, y, y), bs=16)\n",
    "preds,targs = learn.get_preds(dl=dl)\n",
    "test_eq(targs, [y,y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with inputs that are tuples\n",
    "class _TupleModel(Module):\n",
    "    def __init__(self, model): self.model=model\n",
    "    def forward(self, x1, x2): return self.model(x1)\n",
    "\n",
    "learn = synth_learner(n_train=5)\n",
    "#learn.dls.n_inp=2\n",
    "x = torch.randn(16*5)\n",
    "y = 2*x + 3 + 0.1*torch.randn(16*5)\n",
    "learn.model = _TupleModel(learn.model)\n",
    "learn.dls = DataLoaders(TfmdDL(TensorDataset(x, x, y), bs=16),TfmdDL(TensorDataset(x, x, y), bs=16))\n",
    "inps,preds,targs = learn.get_preds(ds_idx=0, with_input=True)\n",
    "test_eq(inps, [x,x])\n",
    "t = learn.get_preds(ds_idx=0, with_input=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test auto activation function is picked\n",
    "learn = synth_learner(n_train=5)\n",
    "learn.loss_func = BCEWithLogitsLossFlat()\n",
    "x = torch.randn(16*5)\n",
    "y = 2*x + 3 + 0.1*torch.randn(16*5)\n",
    "dl = TfmdDL(TensorDataset(x, y), bs=16)\n",
    "preds,targs = learn.get_preds(dl=dl)\n",
    "test_close(preds, torch.sigmoid(learn.model(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test reorder is done\n",
    "learn = synth_learner(n_train=5)\n",
    "x = torch.randn(16*5)\n",
    "y = 2*x + 3 + 0.1*torch.randn(16*5)\n",
    "dl = TfmdDL(TensorDataset(x, y), bs=16, shuffle=True)\n",
    "preds,targs = learn.get_preds(dl=dl)\n",
    "test_eq(targs, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "inps,preds,targs = learn.get_preds(ds_idx=0, with_input=True)\n",
    "tst = learn.get_preds(ds_idx=0, with_input=True, with_decoded=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.predict\" class=\"doc_header\"><code>Learner.predict</code><a href=\"__main__.py#L169\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.predict</code>(**`item`**, **`rm_type_tfms`**=*`None`*, **`with_input`**=*`False`*)\n",
       "\n",
       "Prediction on `item`, fully decoded, loss function decoded and probabilities"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It returns a tuple of three elements. Listed here in reverse order (the tuple starts with the last one), they are:\n",
    "- the prediction from the model, potentially passed through the activation of the loss function (if it has one)\n",
    "- the decoded prediction, using the loss function's <code>decodes</code> method (if it has one)\n",
    "- the fully decoded prediction, using the transforms used to build the `Datasets`/`DataLoaders`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`rm_type_tfms` is a deprecated argument that should not be used and will be removed in a future version. `with_input` will add the decoded inputs to the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fake loss with an activation (+1) and decodes (*2) so each decoding stage of `predict` is visible\n",
    "class _FakeLossFunc(Module):\n",
    "    reduction = 'none'\n",
    "    def forward(self, x, y): return F.mse_loss(x,y)\n",
    "    def activation(self, x): return x+1\n",
    "    def decodes(self, x):    return 2*x\n",
    "\n",
    "# Target transform (second tfms list below); its `decodes` (-1) is the last step of the full decoding\n",
    "class _Add1(Transform):\n",
    "    def encodes(self, x): return x+1\n",
    "    def decodes(self, x): return x-1\n",
    "    \n",
    "learn = synth_learner(n_train=5)\n",
    "dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))\n",
    "learn.dls = DataLoaders(dl, dl)\n",
    "learn.loss_func = _FakeLossFunc()\n",
    "\n",
    "# Recompute every stage by hand and check `predict` returns them as (full_dec, dec, out)\n",
    "inp = tensor([2.])\n",
    "out = learn.model(inp).detach()+1  #applying model + activation\n",
    "dec = 2*out                        #decodes from loss function\n",
    "full_dec = dec-1                   #decodes from _Add1\n",
    "test_eq(learn.predict(inp), [full_dec,dec,out])\n",
    "test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.show_results\" class=\"doc_header\"><code>Learner.show_results</code><a href=\"__main__.py#L180\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.show_results</code>(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`max_n`**=*`9`*, **`shuffle`**=*`True`*, **\\*\\*`kwargs`**)\n",
       "\n",
       "Show some predictions on `ds_idx`-th dataset or `dl`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.show_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Will show `max_n` samples (unless the batch size of `ds_idx` or `dl` is less than `max_n`, in which case it will show as many samples as the batch size) and will `shuffle` the data unless you pass `False` to that flag. `kwargs` are application-dependent.\n",
    "\n",
    "We can't show an example on our synthetic `Learner`, but check the beginner tutorials, which show how that method works across applications."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remaining functions in this section are used internally for inference and are less likely to be useful to you directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.no_logging\" class=\"doc_header\"><code>Learner.no_logging</code><a href=\"__main__.py#L193\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.no_logging</code>()\n",
       "\n",
       "Context manager to temporarily remove `logger`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.no_logging)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner(n_train=5, metrics=tst_metric)\n",
    "# Inside `no_logging`, fitting prints nothing; afterwards `logger` is `print` again\n",
    "with learn.no_logging():\n",
    "    test_stdout(lambda: learn.fit(1), '')\n",
    "test_eq(learn.logger, print)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Learner.loss_not_reduced\" class=\"doc_header\"><code>Learner.loss_not_reduced</code><a href=\"__main__.py#L198\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Learner.loss_not_reduced</code>()\n",
       "\n",
       "A context manager to evaluate `loss_func` with reduction set to none."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Learner.loss_not_reduced)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This requires your loss function to either have a `reduction` attribute or a `reduction` argument (like all fastai and PyTorch loss functions)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# `reduction` is 'none' only inside the context manager, so the loss is per item (same shape as `y`)\n",
    "test_eq(learn.loss_func.reduction, 'mean')\n",
    "with learn.loss_not_reduced():\n",
    "    test_eq(learn.loss_func.reduction, 'none')\n",
    "    x,y = learn.dls.one_batch()\n",
    "    p = learn.model(x)\n",
    "    losses = learn.loss_func(p, y)\n",
    "    test_eq(losses.shape, y.shape)\n",
    "    test_eq(losses, F.mse_loss(p,y, reduction='none'))\n",
    "# Original reduction is restored on exit\n",
    "test_eq(learn.loss_func.reduction, 'mean')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transfer learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def freeze_to(self:Learner, n):\n",
    "    # Freezing is done through the optimizer's parameter groups, so build it on demand first\n",
    "    if self.opt is None: self.create_opt()\n",
    "    opt = self.opt\n",
    "    opt.freeze_to(n)\n",
    "    # Reset optimizer state now that the set of trainable groups has changed\n",
    "    opt.clear_state()\n",
    "\n",
    "@patch\n",
    "def freeze(self:Learner):\n",
    "    # Everything up to the last parameter group is frozen\n",
    "    self.freeze_to(-1)\n",
    "\n",
    "@patch\n",
    "def unfreeze(self:Learner):\n",
    "    # Starting the freeze at group 0 leaves every group trainable\n",
    "    self.freeze_to(0)\n",
    "\n",
    "add_docs(Learner,\n",
    "         freeze_to=\"Freeze parameter groups up to `n`\",\n",
    "         freeze=\"Freeze up to last parameter group\",\n",
    "         unfreeze=\"Unfreeze the entire model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 29.69741439819336, 28.89706039428711, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "# Test model: `forward` only uses `a` and `b`; `tst` (Linear + BatchNorm) just supplies parameter groups to freeze\n",
    "class _TstModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n",
    "        self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))\n",
    "        self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3) \n",
    "    def forward(self, x): return x * self.a + self.b\n",
    "    \n",
    "# Force a gradient of ones on every trainable `tst` parameter so each SGD step is predictable\n",
    "class _PutGrad(Callback):\n",
    "    def after_backward(self):\n",
    "        for p in self.learn.model.tst.parameters():\n",
    "            if p.requires_grad: p.grad = torch.ones_like(p.data)\n",
    "\n",
    "# Three parameter groups: the Linear layer, the BatchNorm layer, then `a` and `b`\n",
    "def _splitter(m): return [list(m.tst[0].parameters()), list(m.tst[1].parameters()), [m.a,m.b]]\n",
    "            \n",
    "learn = synth_learner(n_train=5, opt_func = partial(SGD), cbs=_PutGrad, splitter=_splitter, lr=1e-2)\n",
    "learn.model = _TstModel()\n",
    "learn.freeze()\n",
    "init = [p.clone() for p in learn.model.tst.parameters()]\n",
    "learn.fit(1, wd=0.)\n",
    "end = list(learn.model.tst.parameters())\n",
    "#linear was not trained\n",
    "for i in [0,1]: test_close(end[i],init[i])\n",
    "#bn was trained even frozen since `train_bn=True` by default\n",
    "# expected shift presumably = 5 batches (n_train=5) x lr 1e-2 x grad 1 = -0.05\n",
    "for i in [2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 16.035701751708984, 16.854097366333008, '00:00']\n",
      "[0, 13.14907169342041, 13.846001625061035, '00:00']\n",
      "[0, 10.72517204284668, 11.376962661743164, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "# With `train_bn=False`, BatchNorm parameters in a frozen group are no longer trained\n",
    "learn = synth_learner(n_train=5, opt_func = partial(SGD), cbs=_PutGrad, splitter=_splitter, train_bn=False, lr=1e-2)\n",
    "learn.model = _TstModel()\n",
    "learn.freeze()\n",
    "init = [p.clone() for p in learn.model.tst.parameters()]\n",
    "learn.fit(1, wd=0.)\n",
    "end = list(learn.model.tst.parameters())\n",
    "#linear and bn were not trained\n",
    "for i in range(4): test_close(end[i],init[i])\n",
    "\n",
    "# `freeze_to(-2)` leaves the last two groups (BatchNorm, then `a`/`b`) trainable\n",
    "learn.freeze_to(-2)\n",
    "init = [p.clone() for p in learn.model.tst.parameters()]\n",
    "learn.fit(1, wd=0.)\n",
    "end = list(learn.model.tst.parameters())\n",
    "#linear was not trained\n",
    "for i in [0,1]: test_close(end[i],init[i])\n",
    "#bn was trained \n",
    "for i in [2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))\n",
    "    \n",
    "# `unfreeze` makes every group trainable\n",
    "learn.unfreeze()\n",
    "init = [p.clone() for p in learn.model.tst.parameters()]\n",
    "learn.fit(1, wd=0.)\n",
    "end = list(learn.model.tst.parameters())\n",
    "#linear and bn were trained\n",
    "for i in range(4): test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]), 1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exporting a `Learner`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def export(self:Learner, fname='export.pkl', pickle_protocol=2):\n",
    "    \"Export the content of `self` without the items and the optimizer state for inference\"\n",
    "    if rank_distrib(): return # don't export if child proc\n",
    "    self._end_cleanup()\n",
    "    # Temporarily strip the data items and the optimizer so they are not serialized\n",
    "    old_dbunch = self.dls\n",
    "    self.dls = self.dls.new_empty()\n",
    "    state = self.opt.state_dict() if self.opt is not None else None\n",
    "    self.opt = None\n",
    "    try:\n",
    "        with warnings.catch_warnings():\n",
    "            #To avoid the warning that comes from PyTorch about model not being checked\n",
    "            warnings.simplefilter(\"ignore\")\n",
    "            torch.save(self, self.path/fname, pickle_protocol=pickle_protocol)\n",
    "    finally:\n",
    "        # Restore even if saving fails (e.g. an unpicklable attribute), so a failed export\n",
    "        # does not leave `self` with empty `dls` and no optimizer\n",
    "        self.dls = old_dbunch\n",
    "        self.create_opt()\n",
    "        if state is not None: self.opt.load_state_dict(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Learner` is saved in `self.path/fname`, using `pickle_protocol`. Note that serialization in Python saves the names of functions, not the code itself. Therefore, any custom code you have for models, data transformations, loss functions, etc. should be put in a module that you will import in your training environment before exporting, and in your deployment environment before loading it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def load_learner(fname, cpu=True):\n",
    "    \"Load a `Learner` object in `fname`, optionally putting it on the `cpu`\"\n",
    "    # Presumably lets distributed processes sync before reading the file\n",
    "    distrib_barrier()\n",
    "    # SECURITY: `torch.load` unpickles the file, which can execute arbitrary code; only load trusted exports\n",
    "    device = 'cpu' if cpu else None\n",
    "    learn = torch.load(fname, map_location=device)\n",
    "    # Convert back via `to_fp32` when the loaded object provides it (presumably undoes mixed precision)\n",
    "    if hasattr(learn, 'to_fp32'): learn = learn.to_fp32()\n",
    "    if cpu: learn.dls.cpu()\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Warning: `load_learner` requires all your custom code be in the exact same place as when exporting your `Learner` (the main script, or the module you imported it from)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TTA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):\n",
    "    \"Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation\"\n",
    "    if dl is None: dl = self.dls[ds_idx]\n",
    "    # NOTE(review): both `after_item` and `after_batch` go to `dl.new`, so the one left as None presumably replaces the existing pipeline - TODO confirm\n",
    "    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)\n",
    "    try:\n",
    "        # Presumably fires the start-of-fit/epoch callbacks; `event.after_fit` in `finally` closes them\n",
    "        self(_before_epoch)\n",
    "        # split_idx 0 makes `dl` apply the training-set (augmenting) transforms\n",
    "        with dl.dataset.set_split_idx(0), self.no_mbar():\n",
    "            if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))\n",
    "            aug_preds = []\n",
    "            for i in self.progress.mbar if hasattr(self,'progress') else range(n):\n",
    "                self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch\n",
    "                # Keep only the predictions and add a leading axis so the `n` passes can be concatenated\n",
    "                aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])\n",
    "        aug_preds = torch.cat(aug_preds)\n",
    "        # Reduce over the `n` augmented passes\n",
    "        aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)\n",
    "        self.epoch = n\n",
    "        # One plain pass with the validation transforms (split_idx 1)\n",
    "        with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)\n",
    "    finally: self(event.after_fit)\n",
    "\n",
    "    # `use_max` ignores `beta` and returns the element-wise max of both predictions\n",
    "    if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs\n",
    "    # beta=None returns the pair (aug_preds, preds) unblended; otherwise lerp gives (1-beta)*aug_preds + beta*preds\n",
    "    preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)\n",
    "    return preds,targs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In practice, we get the predictions `n` times with the transforms of the training set and average those. The final predictions are `(1-beta)` multiplied by this average + `beta` multiplied by the predictions obtained with the transforms of the validation set. Set `beta` to `None` to get a tuple of the TTA results and the regular predictions, in that order. You can also use the maximum of all predictions instead of an average by setting `use_max=True`; in that case `beta` is ignored and the element-wise maximum of the TTA and regular predictions is returned.\n",
    "\n",
    "If you want to use new transforms, you can pass them with `item_tfms` and `batch_tfms`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn = synth_learner()\n",
    "dl = TfmdDL(Datasets(torch.arange(50), [noop,noop]))\n",
    "learn.dls = DataLoaders(dl, dl)\n",
    "preds,targs = learn.tta()\n",
    "# One prediction per item of the validation `dl`, aligned with the targets\n",
    "test_eq(len(preds), len(targs))\n",
    "test_eq(len(targs), len(dl.dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
